In [1]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib
import torch
import torch.utils.data as data
import torchnet as tnt
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix
import os
import json
import pickle as pkl
import pprint
import time


PATH_AUG = 'augs'
sys.path.append( PATH_AUG )
import augmentation
import cutmix
import cutout
import mixup
import random_shif
import windowwarp
import get_augs

In [2]:
print(torch.cuda.device_count())

1


In [3]:


def augs(BATCH_SIZE, SHAPE, DO_PROB, element_prob):

    mix_up = mixup.Mixup(batch_size = BATCH_SIZE,
        do_prob = DO_PROB,
        sequence_shape = SHAPE[1:],
        linear_mix_min = 0.1,
        linear_mix_max = 0.5)

    cut_mix = cutmix.Cutmix(batch_size = BATCH_SIZE,
            do_prob = DO_PROB,
            sequence_shape = SHAPE[1:],
            min_cutmix_len = SHAPE[1] // 2,
            max_cutmix_len = SHAPE[1],
            channel_replace_prob = element_prob,
            )

    cut_mix.batch = BATCH_SIZE
    cut_out = cutout.Cutout(
            batch_size = BATCH_SIZE,
            do_prob = DO_PROB,
            sequence_shape = SHAPE[1:],
            min_cutout_len = SHAPE[1] // 2,
            max_cutout_len = SHAPE[1],
            channel_drop_prob = element_prob,
    )
    
    return mix_up, cut_out, cut_mix

def batch_aug(x, y, mix_up, cut_out, cut_mix):
    example = {'input': x, 'target': y}
    example = cut_mix(example)
    example = cut_out(example)
    example = mix_up(example)
    x, y = example['input'], example['target']
    return x, y

mixUp, cutOut, cutMix = augs( 2, (2, 30, 10, 128, 128), 0.7, 0.5)

In [4]:
PATH_TO_PASTIS = './PASTIS'
PATH_TO_PAPS = './utae-paps/'
sys.path.append(PATH_TO_PAPS)

In [5]:
#pastis import
#from dataloader import PASTIS_Dataset
#from collate import pad_collate

#model import
import src.model_utils
from src.backbones.utae import UTAE
from src.learning.miou import *
from src.learning.weight_init import *
from src import utils

from src.dataset import *

In [6]:
#pastis function

cm = matplotlib.cm.get_cmap('tab20')
def_colors = cm.colors
cus_colors = ['k'] + [def_colors[i] for i in range(1,20)]+['w']
cmap = ListedColormap(colors = cus_colors, name='agri',N=21)

def get_rgb(x, batch_index=0, t_show=1):
    """Utility function to get a displayable rgb image 
    from a Sentinel-2 time series.
    """
    im = x['S2'][batch_index, t_show, [2,1,0]].cpu().numpy()
    mx = im.max(axis=(1,2))
    mi = im.min(axis=(1,2))   
    im = (im - mi[:,None,None])/(mx - mi)[:,None,None]
    im = im.swapaxes(0,2).swapaxes(0,1)
    im = np.clip(im, a_max=1, a_min=0)
    return im

In [7]:
def iterate( model, data_loader, criterion, config, optimizer=None, mode="train", do_augs=False, device=None ):
    loss_meter = tnt.meter.AverageValueMeter()
    iou_meter = IoU(
        num_classes=config[ 'num_classes' ],
        ignore_index=config[ 'ignore_index' ],
        cm_device=config[ 'device' ],
    )

    t_start = time.time()
    for i, batch in enumerate(data_loader):
        
        if do_augs:
            (x, d), y = batch
            print( x.shape )
            x, y = batch_aug( x[:, :30, :, :, :], y, mixUp, cutOut, cutMix )
            batch = (x, d), y
             
        if device is not None:
            batch = recursive_todevice(batch, device)
        (x, dates), y = batch
        y = y.long()

        if mode != "train":
            with torch.no_grad():
                out = model(x, batch_positions=dates)
        else:
            optimizer.zero_grad()
            out = model(x, batch_positions=dates)
        

        loss = criterion(out, y)
        if mode == "train":
            loss.backward()
            optimizer.step()
          
        #print( out.shape, y.shape )

        with torch.no_grad():
            pred = out.argmax(dim=1)
        iou_meter.add(pred, y)
        loss_meter.add(loss.item())

        if (i + 1) % config[ 'display_step' ] == 0:
            miou, acc = iou_meter.get_miou_acc()
            print(
                "Step [{}/{}], Loss: {:.4f}, Acc : {:.2f}, mIoU {:.2f}".format(
                    i + 1, len(data_loader), loss_meter.value()[0], acc, miou
                )
            )

    t_end = time.time()
    total_time = t_end - t_start
    print("Epoch time : {:.1f}s".format(total_time))
    miou, acc = iou_meter.get_miou_acc()
    metrics = {
        "{}_accuracy".format(mode): acc,
        "{}_loss".format(mode): loss_meter.value()[0],
        "{}_IoU".format(mode): miou,
        "{}_epoch_time".format(mode): total_time,
    }

    if mode == "test":
        return metrics, iou_meter.conf_metric.value()  # confusion matrix
    else:
        return metrics


def recursive_todevice(x, device):
    if isinstance(x, torch.Tensor):
        return x.to(device)
    elif isinstance(x, dict):
        return {k: recursive_todevice(v, device) for k, v in x.items()}
    else:
        return [recursive_todevice(c, device) for c in x]


def prepare_output(config):
    os.makedirs(config[ 'res_dir' ], exist_ok=True)
    for fold in range(1, 2):
        os.makedirs(os.path.join(config[ 'res_dir' ], "Fold_{}".format(fold)), exist_ok=True)


def checkpoint(fold, log, config):
    with open(
        os.path.join(config[ 'res_dir' ], "Fold_{}".format(fold), "trainlog.json"), "w"
    ) as outfile:
        json.dump(log, outfile, indent=4)


def save_results(fold, metrics, conf_mat, config):
    with open(
        os.path.join(config[ 'res_dir' ], "Fold_{}".format(fold), "test_metrics.json"), "w"
    ) as outfile:
        json.dump(metrics, outfile, indent=4)
    pkl.dump(
        conf_mat,
        open(
            os.path.join(config[ 'res_dir' ], "Fold_{}".format(fold), "conf_mat.pkl"), "wb"
        ),
    )


def overall_performance(config):
    cm = np.zeros((config[ 'num_classes' ], config[ 'num_classes' ]))
    for fold in range(1, 6):
        cm += pkl.load(
            open(
                os.path.join(config[ 'res_dir' ], "Fold_{}".format(fold), "conf_mat.pkl"),
                "rb",
            )
        )

    if config.ignore_index is not None:
        cm = np.delete(cm, config[ 'ignore_index' ], axis=0)
        cm = np.delete(cm, config[ 'ignore_index' ], axis=1)

    _, perf = confusion_matrix_analysis(cm)

    print("Overall performance:")
    print("Acc: {},  IoU: {}".format(perf["Accuracy"], perf["MACRO_IoU"]))

    with open(os.path.join(config[ 'res_dir' ], "overall.json"), "w") as file:
        file.write(json.dumps(perf, indent=4))



In [8]:
pad_collate = lambda x: utils.pad_collate(x, pad_value = 0)

fold_sequence = [[1, 2], [4], [5]]
fold = 0

pastis_train_dataset = PASTIS_Dataset(PATH_TO_PASTIS, folds = fold_sequence[ 0 ],  norm=True, target='semantic')
pastis_test_dataset  = PASTIS_Dataset(PATH_TO_PASTIS, folds = fold_sequence[ 1 ], norm=True, target='semantic')
pastis_eval_dataset = PASTIS_Dataset(PATH_TO_PASTIS, folds = fold_sequence[ 2 ], norm=True, target='semantic')

print( len(pastis_train_dataset) )
print( len(pastis_test_dataset) )
print( len(pastis_eval_dataset) )

train_loader = torch.utils.data.DataLoader(pastis_train_dataset, batch_size=2, collate_fn=pad_collate, shuffle=True)
test_loader = torch.utils.data.DataLoader(pastis_test_dataset, batch_size=2, collate_fn=pad_collate, shuffle=True)
evaluate_loader = torch.utils.data.DataLoader(pastis_eval_dataset, batch_size=2, collate_fn=pad_collate, shuffle=True)


Reading patch metadata . . .


  for pid, date_seq in dates.iteritems():


Done.
Dataset ready.
Reading patch metadata . . .


  for pid, date_seq in dates.iteritems():


Done.
Dataset ready.
Reading patch metadata . . .


  for pid, date_seq in dates.iteritems():


Done.
Dataset ready.
981
482
496


In [9]:
config = {}
config[ 'epoch' ] = 100
config[ 'num_classes' ] = 20
config[ 'val_after' ] = 0 #nb epoch avant premiere validation
config[ 'val_every' ] = 1 #nb epoch entre chaque validation
config[ 'device' ] = 'cuda'
config[ 'ignore_index' ] = -1
config[ 'display_step' ] = 50
config[ 'res_dir' ] = './results'
config[ 'model' ] = 'utae'

torch.cuda.empty_cache()
#torch.cuda.max_split_size_mb( 1024 )

model_config = {}

model = UTAE(
        input_dim = 10,
        encoder_widths = [ 64, 64, 64, 128 ],
        decoder_widths = [ 32, 32, 64, 128 ],
        out_conv = [ 32, 20 ],
        str_conv_k = 4,
        str_conv_s = 2,
        str_conv_p = 1,
        agg_mode = "att_group",
        encoder_norm = "group",
        n_head = 16,
        d_model = 256,
        d_k = 4,
        encoder = False,
        return_maps = False,
        pad_value = 0,
        padding_mode = "reflect",
    )

device = config[ 'device' ]

  T, 2 * (torch.arange(offset, offset + d).float() // 2) / d


In [10]:

prepare_output(config)

model = model.to( device )
model.apply( weight_init )

optimizer = torch.optim.Adam( model.parameters(), lr = 0.001 )

weights = torch.ones( config[ 'num_classes' ], device = device ).float()
criterion = torch.nn.CrossEntropyLoss( weight = weights )

trainlog = {}
best_mIoU = 0

for e in range( 1, config[ 'epoch' ] + 1 ):
    print( "EPOCH {}/{}".format( e, config[ 'epoch' ] ) )
    
    model.train()
    
    train_metrics = iterate( model, train_loader, criterion, config = config, optimizer = optimizer, mode = "train", do_augs=True, device = device )
    
    if e % config[ 'val_every' ] == 0 and e > config[ 'val_after' ]:
        print("Validation . . . ")
        model.eval()
        val_metrics = iterate(
            model,
            data_loader=evaluate_loader,
            criterion=criterion,
            config=config,
            optimizer=optimizer,
            mode="val",
            device=device,
        )

        print(
            "Loss {:.4f},  Acc {:.2f},  IoU {:.4f}".format(
                val_metrics["val_loss"],
                val_metrics["val_accuracy"],
                val_metrics["val_IoU"],
            )
        )

        trainlog[e] = {**train_metrics, **val_metrics}
        checkpoint(fold + 1, trainlog, config)
        if val_metrics["val_IoU"] >= best_mIoU:
            best_mIoU = val_metrics["val_IoU"]
            torch.save(
                {
                    "epoch": e,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                os.path.join(
                    config[ 'res_dir' ], "Fold_{}".format(fold + 1), "model.pth.tar"
                ),
            )
    else:
        trainlog[e] = { **train_metrics }
        checkpoint( fold + 1, trainlog, config )



print("Testing best epoch . . .")
model.load_state_dict(
    torch.load(
            os.path.join(
            config[ 'res_dir' ], "Fold_{}".format(fold + 1), "model.pth.tar"
    )
    )["state_dict"]
)

model.eval()

test_metrics, conf_mat = iterate(
    model,
    data_loader=test_loader,
    criterion=criterion,
    config=config,
    optimizer=optimizer,
    mode="test",
    device=device,
)
print(
    "Loss {:.4f},  Acc {:.2f},  IoU {:.4f}".format(
        test_metrics["test_loss"],
        test_metrics["test_accuracy"],
        test_metrics["test_IoU"],
    )
)
save_results(1, test_metrics, conf_mat.cpu().numpy(), config)

EPOCH 1/100
torch.Size([2, 43, 10, 128, 128])


  time = torch.range(0, self.sequence_shape[0] - 1, dtype=torch.float32)
  time = torch.range(0, self.sequence_shape[0] - 1, dtype=torch.float32)


RuntimeError: HIP out of memory. Tried to allocate 240.00 MiB (GPU 0; 4.00 GiB total capacity; 359.12 MiB already allocated; 3.62 GiB free; 362.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_HIP_ALLOC_CONF