In [None]:
import sys
import numpy as np
import os
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data as data
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image
from itertools import cycle

from google.colab import drive

In [None]:
# Mount Google Drive (Colab only). This is only needed when `data_path` in the
# training cell points inside /content/drive; the default 'range_dataset' is a
# path relative to the working directory.
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


# Utils

In [None]:
def get_files(folder, name_filter=None, extension_filter=None):
    """Return the names of the files under ``folder`` that pass both filters.

    The whole directory tree is walked with ``os.walk``; inside each directory
    the file names are visited in sorted order. Only the bare file names are
    returned (NOT full paths), so callers must join them back onto the
    directory they came from.

    Keyword arguments:
    - folder (``string``): The path to a folder.
    - name_filter (``string``, optional): keep only files whose name contains
    this substring. Default: None; names are not filtered.
    - extension_filter (``string``, optional): keep only files whose name ends
    with this suffix. Default: None; extensions are not filtered.

    Raises:
    - RuntimeError: if ``folder`` is not an existing directory.
    """
    if not os.path.isdir(folder):
        raise RuntimeError("\"{0}\" is not a folder.".format(folder))

    def keep(filename):
        # A filter that is None accepts every file.
        if name_filter is not None and name_filter not in filename:
            return False
        if extension_filter is not None and not filename.endswith(extension_filter):
            return False
        return True

    return [
        filename
        for _, _, names in os.walk(folder)
        for filename in sorted(names)
        if keep(filename)
    ]



# Maps every raw label id stored in the range files (the ids/names look like
# the SemanticKITTI label set — confirm against the dataset config) to one of
# the 20 training class ids 0..19 listed in `class_encoding`. Entries marked
# "mapped" merge a raw class into another training class; "moving-*" classes
# are folded into their static counterparts.
learning_map = {
  0 : 0,    # "unlabeled"
  1 : 0,     # "outlier" mapped to "unlabeled" --------------------------mapped
  10: 1,     # "car"
  11: 2,     # "bicycle"
  13: 5,     # "bus" mapped to "other-vehicle" --------------------------mapped
  15: 3,     # "motorcycle"
  16: 5,     # "on-rails" mapped to "other-vehicle" ---------------------mapped
  18: 4,     # "truck"
  20: 5,     # "other-vehicle"
  30: 6,     # "person"
  31: 7,     # "bicyclist"
  32: 8,     # "motorcyclist"
  40: 9,     # "road"
  44: 10,    # "parking"
  48: 11,    # "sidewalk"
  49: 12,    # "other-ground"
  50: 13,    # "building"
  51: 14,    # "fence"
  52: 0,     # "other-structure" mapped to "unlabeled" ------------------mapped
  60: 9,     # "lane-marking" to "road" ---------------------------------mapped
  70: 15,    # "vegetation"
  71: 16,    # "trunk"
  72: 17,    # "terrain"
  80: 18,    # "pole"
  81: 19,    # "traffic-sign"
  99: 0,     # "other-object" to "unlabeled" ----------------------------mapped
  252: 1,    # "moving-car" to "car" ------------------------------------mapped
  253: 7,    # "moving-bicyclist" to "bicyclist" ------------------------mapped
  254: 6,    # "moving-person" to "person" ------------------------------mapped
  255: 8,    # "moving-motorcyclist" to "motorcyclist" ------------------mapped
  256: 5,    # "moving-on-rails" mapped to "other-vehicle" --------------mapped
  257: 5,    # "moving-bus" mapped to "other-vehicle" -------------------mapped
  258: 4,    # "moving-truck" to "truck" --------------------------------mapped
  259: 5,    # "moving-other-vehicle" mapped to "other-vehicle" ---------mapped
}

# Frequency of each RAW label id (same keys as `learning_map`), expressed as
# the fraction of all points carrying that label. Despite the name these are
# frequencies, not loss weights: the loss weights `weights` are built from the
# inverse of the per-training-class sums of these values.
# Ids 16 ("on-rails") and 256 ("moving-on-rails") never occur (0.0).
class_weights = { # as a ratio with the total number of points
  0: 0.018889854628292943,
  1: 0.0002937197336781505,
  10: 0.040818519255974316,
  11: 0.00016609538710764618,
  13: 2.7879693665067774e-05,
  15: 0.00039838616015114444,
  16: 0.0,
  18: 0.0020633612104619787,
  20: 0.0016218197275284021,
  30: 0.00017698551338515307,
  31: 1.1065903904919655e-08,
  32: 5.532951952459828e-09,
  40: 0.1987493871255525,
  44: 0.014717169549888214,
  48: 0.14392298360372,
  49: 0.0039048553037472045,
  50: 0.1326861944777486,
  51: 0.0723592229456223,
  52: 0.002395131480328884,
  60: 4.7084144280367186e-05,
  70: 0.26681502148037506,
  71: 0.006035012012626033,
  72: 0.07814222006271769,
  80: 0.002855498193863172,
  81: 0.0006155958086189918,
  99: 0.009923127583046915,
  252: 0.001789309418528068,
  253: 0.00012709999297008662,
  254: 0.00016059776092534436,
  255: 3.745553104802113e-05,
  256: 0.0,
  257: 0.00011351574470342043,
  258: 0.00010157861367183268,
  259: 4.3840131989471124e-05,
}

# Per-class loss weights for the 20 training classes: the inverse of the summed
# frequency of every raw label that `learning_map` folds into that class.
# `total_weights` accumulates the overall mapped frequency (kept for inspection).
total_weights = 0
mapped_class_weights = []
for cls in range(20):
  weight = 0
  for key, train_id in learning_map.items():
    if train_id != cls:
      continue
    freq = class_weights[key]
    weight += freq
    total_weights += freq
  mapped_class_weights.append(1 / weight)

# float64 tensor; callers cast with .float() where needed.
weights = torch.from_numpy(np.asarray(mapped_class_weights))


# Training class name -> training class id (0..19), in id order; these are the
# target values produced by `learning_map`.
class_encoding = {
  name: index
  for index, name in enumerate([
    "unlabeled", "car", "bicycle", "motorcycle", "truck", "other-vehicle",
    "person", "bicyclist", "motorcyclist", "road", "parking", "sidewalk",
    "other-ground", "building", "fence", "vegetation", "trunk", "terrain",
    "pole", "traffic-sign",
  ])
}

def map_labels(mask, lr_map = learning_map):
  """Translate every raw label id in ``mask`` to its training id via ``lr_map``.

  Keyword arguments:
  - mask (array-like): raw label ids, any shape. Float values are truncated
    toward zero before lookup, exactly like ``int()``.
  - lr_map (dict): raw id -> training id. Default: ``learning_map``.

  Returns a float64 ``numpy.ndarray`` with the same shape as ``mask``.
  Raises KeyError if ``mask`` contains an id that ``lr_map`` does not know.
  """
  # astype(int64) truncates toward zero, matching the former per-pixel int().
  raw_ids = np.asarray(mask).astype(np.int64)
  # Look up each *distinct* id once instead of once per pixel: a 64x1024 range
  # image has 65536 pixels but only a handful of distinct labels.
  unique_ids, inverse = np.unique(raw_ids, return_inverse=True)
  mapped = np.array([lr_map[int(i)] for i in unique_ids], dtype=np.float64)
  # reshape: np.unique's inverse is flat in some NumPy versions.
  return mapped[inverse].reshape(raw_ids.shape)



class RangeKitti(data.Dataset):
  """Range-image dataset read from ``<root_dir>/<sequence>/range/`` files.

  Each file is a raw int32 buffer reshaped to (64, 1024, 6) and divided by
  1000. Channels 0-4 are returned as the network input, shaped (5, 64, 1024);
  channel 5 holds the raw label id, remapped to training ids with
  ``map_labels`` and returned as a (64, 1024) int64 tensor.

  Keyword arguments:
  - root_dir (``string``): folder containing one sub-folder per sequence.
  - mode (``string``): split to load: 'train_with_labels' (sequences 00-03),
    'train_without_labels' (04-07) or 'test' (08-09).

  Raises:
  - ValueError: if ``mode`` is not one of the splits above.
  """

  # Sequence folders belonging to each split.
  SPLITS = {
    'train_with_labels': ['00', '01', '02', '03'],
    'train_without_labels': ['04', '05', '06', '07'],
    'test': ['08', '09'],
  }

  def __init__(self, root_dir, mode):
    # Previously an unknown mode only printed a message and then crashed with
    # an UnboundLocalError on `folders`; fail early with a clear error instead.
    if mode not in self.SPLITS:
      raise ValueError("unknown mode {!r}; expected one of {}".format(
          mode, sorted(self.SPLITS)))

    self.root_dir = root_dir
    self.files = []
    for folder in self.SPLITS[mode]:
      range_dir = os.path.join(self.root_dir, folder, 'range')
      # get_files returns bare file names, so join them back onto range_dir.
      self.files.extend(os.path.join(range_dir, name) for name in get_files(range_dir))

  def __len__(self):
    return len(self.files)

  def __getitem__(self, index):
    path = self.files[index]
    # Stored as int32 fixed-point (value * 1000), (H=64, W=1024, C=6).
    proj = np.fromfile(path, dtype=np.int32).reshape(64, 1024, 6)
    proj = proj.astype(np.float32) / 1000

    mask = map_labels(proj[:, :, 5])

    # (H, W, C) -> (C, H, W) for the 5 input channels.
    features = torch.from_numpy(proj[:, :, 0:5]).permute(2, 0, 1)
    return features, torch.from_numpy(mask).long()


# Metrics

In [None]:
class Metric(object):
    """Base class for all metrics.
    From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py

    Subclasses override ``reset`` (clear the accumulated state), ``add``
    (accumulate one batch) and ``value`` (report the current result).
    """
    def reset(self):
        pass

    def add(self):
        pass

    def value(self):
        pass


class ConfusionMatrix(Metric):
    """Constructs a confusion matrix for a multi-class classification problems.
    Does not support multi-label, multi-class problems.
    Keyword arguments:
    - num_classes (int): number of classes in the classification problem.
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized or not. Default: False.
    Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
    """

    def __init__(self, num_classes, normalized=False):
        super().__init__()

        # int64 counts: segmentation adds one entry per pixel, and an int32
        # accumulator would silently wrap past ~2.1e9 counts in a cell.
        self.conf = np.zeros((num_classes, num_classes), dtype=np.int64)
        self.normalized = normalized
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        """Zero all accumulated counts."""
        self.conf.fill(0)

    def add(self, predicted, target):
        """Computes the confusion matrix
        The shape of the confusion matrix is K x K, where K is the number
        of classes.
        Keyword arguments:
        - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        predicted scores obtained from the model for N examples and K classes,
        or an N-tensor/array of integer values between 0 and K-1.
        - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
        ground-truth classes for N examples and K classes, or an N-tensor/array
        of integer values between 0 and K-1.
        An empty batch (N == 0) is accepted and leaves the counts unchanged.
        """
        # If target and/or predicted are tensors, convert them to numpy arrays
        if torch.is_tensor(predicted):
            predicted = predicted.cpu().numpy()
        if torch.is_tensor(target):
            target = target.cpu().numpy()

        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'

        # Nothing to count; also avoids max()/min() on an empty array below.
        if predicted.shape[0] == 0:
            return

        if np.ndim(predicted) != 1:
            assert predicted.shape[1] == self.num_classes, \
                'number of predictions does not match size of confusion matrix'
            predicted = np.argmax(predicted, 1)
        else:
            assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
                'predicted values are not between 0 and k-1'

        if np.ndim(target) != 1:
            assert target.shape[1] == self.num_classes, \
                'Onehot target does not match size of confusion matrix'
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            assert (target.sum(1) == 1).all(), \
                'multi-label setting is not supported'
            target = np.argmax(target, 1)
        else:
            assert (target.max() < self.num_classes) and (target.min() >= 0), \
                'target values are not between 0 and k-1'

        # Encode each (target, predicted) pair as the single index
        # target * K + predicted so one bincount fills the whole K x K matrix
        # (rows = ground truth, columns = prediction).
        x = predicted + self.num_classes * target
        bincount_2d = np.bincount(
            x.astype(np.int64), minlength=self.num_classes**2)
        assert bincount_2d.size == self.num_classes**2
        conf = bincount_2d.reshape((self.num_classes, self.num_classes))

        self.conf += conf

    def value(self):
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows corresponds
            to ground-truth targets and columns corresponds to predicted
            targets. Always a new array: modifying it does not affect the
            accumulated counts.
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        return self.conf.copy()


class IoU(Metric):
    """Computes the intersection over union (IoU) per class and corresponding
    mean (mIoU).
    Intersection over union (IoU) is a common evaluation metric for semantic
    segmentation. The predictions are first accumulated in a confusion matrix
    and the IoU is computed from it as follows:
        IoU = true_positive / (true_positive + false_positive + false_negative).
    Keyword arguments:
    - num_classes (int): number of classes in the classification problem
    - normalized (boolean, optional): Determines whether or not the confusion
    matrix is normalized or not. Default: False.
    - ignore_index (int or iterable, optional): Index of the classes to ignore
    when computing the IoU. Can be an int, or any iterable of ints.
    """

    def __init__(self, num_classes, normalized=False, ignore_index=None):
        super().__init__()
        self.conf_metric = ConfusionMatrix(num_classes, normalized)

        if ignore_index is None:
            self.ignore_index = None
        elif isinstance(ignore_index, (int, np.integer)):
            self.ignore_index = (int(ignore_index),)
        else:
            try:
                self.ignore_index = tuple(ignore_index)
            except TypeError:
                raise ValueError("'ignore_index' must be an int or iterable")

    def reset(self):
        self.conf_metric.reset()

    def add(self, predicted, target):
        """Adds the predicted and target pair to the IoU metric.
        Keyword arguments:
        - predicted (Tensor): Can be a (N, K, H, W) tensor of
        predicted scores obtained from the model for N examples and K classes,
        or (N, H, W) tensor of integer values between 0 and K-1.
        - target (Tensor): Can be a (N, K, H, W) tensor of
        target scores for N examples and K classes, or (N, H, W) tensor of
        integer values between 0 and K-1.
        """
        # Dimensions check
        assert predicted.size(0) == target.size(0), \
            'number of targets and predicted outputs do not match'
        assert predicted.dim() == 3 or predicted.dim() == 4, \
            "predictions must be of dimension (N, H, W) or (N, K, H, W)"
        assert target.dim() == 3 or target.dim() == 4, \
            "targets must be of dimension (N, H, W) or (N, K, H, W)"

        # If the tensor is in categorical format convert it to integer format
        if predicted.dim() == 4:
            _, predicted = predicted.max(1)
        if target.dim() == 4:
            _, target = target.max(1)

        self.conf_metric.add(predicted.view(-1), target.view(-1))

    def value(self):
        """Computes the IoU and mean IoU.
        The mean computation ignores NaN elements of the IoU array; ignored
        classes get NaN (0 / 0) and therefore drop out of the mean.
        Calling this does not modify the accumulated confusion matrix.
        Returns:
            Tuple: (IoU, mIoU). The first output is the per class IoU,
            for K classes it's numpy.ndarray with K elements. The second output,
            is the mean IoU.
        """
        # conf_metric.value() returns a fresh array, so zeroing entries below
        # cannot corrupt the running counts.
        conf_matrix = self.conf_metric.value()
        if self.ignore_index:
            ignored = list(self.ignore_index)
            # Drop ignored classes both as ground truth (rows) and as
            # prediction (columns).
            conf_matrix[:, ignored] = 0
            conf_matrix[ignored, :] = 0
        true_positive = np.diag(conf_matrix)
        false_positive = np.sum(conf_matrix, 0) - true_positive
        false_negative = np.sum(conf_matrix, 1) - true_positive

        # Just in case we get a division by 0, ignore/hide the error
        with np.errstate(divide='ignore', invalid='ignore'):
            iou = true_positive / (true_positive + false_positive + false_negative)

        return iou, np.nanmean(iou)

# ClassMix loss

In [None]:
def class_mix_loss(labeled_outputs, labels, pseudo_labeled_outputs, pseudo_labels, lambda_pseudo=0.5, weight=None):
  """ClassMix objective: supervised cross-entropy plus down-weighted pseudo-label CE.

      loss = CE_w(labeled_outputs, labels)
             + lambda_pseudo * CE_w(pseudo_labeled_outputs, pseudo_labels)

  Keyword arguments:
  - labeled_outputs / pseudo_labeled_outputs (Tensor): logits, (N, K, H, W).
  - labels / pseudo_labels (Tensor): integer targets, (N, H, W).
  - lambda_pseudo (float): weight of the pseudo-label term. Default: 0.5.
  - weight (Tensor, optional): per-class CE weights. Default: None, meaning
    the module-level inverse-frequency ``weights``.
  """
  if weight is None:
    weight = weights
  # Follow the logits' device/dtype instead of hard-coding .cuda(), so the
  # loss also works on CPU (and with non-float32 logits).
  weight = weight.to(device=labeled_outputs.device, dtype=labeled_outputs.dtype)
  supervised = F.cross_entropy(labeled_outputs, labels, weight=weight)
  pseudo = F.cross_entropy(pseudo_labeled_outputs, pseudo_labels, weight=weight)
  return supervised + lambda_pseudo * pseudo

# Train

In [None]:
def _repeat_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass when one ends.

    Unlike ``itertools.cycle`` this does not cache every batch in memory, and a
    shuffling DataLoader is reshuffled on every pass. Stops if a pass yields
    nothing (empty loader) so it can never spin forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def _classmix(model, unlabeled_inputs):
    """Build one ClassMix sample from the first two images of ``unlabeled_inputs``.

    Pseudo labels are the argmax predictions of ``model`` on each image. Pixels
    of a random half of the classes predicted in the first image are copied
    from the first image; all other pixels come from the second image. A
    single-image batch (e.g. a ragged last batch) is mixed with itself.

    Returns (mixed_input, mixed_pseudo_labels) shaped (1, C, H, W) and (1, H, W).
    """
    first = unlabeled_inputs[0:1]
    second = unlabeled_inputs[1:2] if unlabeled_inputs.size(0) > 1 else first

    # Argmax pseudo labels carry no gradient, so don't record autograd graphs
    # for these passes (BatchNorm still updates its stats in train mode).
    with torch.no_grad():
        pseudo_1 = torch.max(model(first), dim=1)[1]
        pseudo_2 = torch.max(model(second), dim=1)[1]

    classes = torch.unique(pseudo_1).tolist()
    chosen = np.random.choice(classes, size=len(classes) // 2, replace=False)
    mask = torch.zeros_like(pseudo_1, dtype=torch.bool)
    for c in chosen:
        mask |= pseudo_1 == int(c)

    # (1, H, W) -> (1, 1, H, W) so the mask broadcasts over input channels.
    mask_img = mask.unsqueeze(1).to(first.dtype)
    mixed_input = mask_img * first + (1 - mask_img) * second
    mixed_pseudo_labels = torch.where(mask, pseudo_1, pseudo_2)
    return mixed_input, mixed_pseudo_labels


def train(model, train_loader_with_labels, train_loader_without_labels, optim, criterion, metric, iteration_loss=False):
    '''
    Training script: trains the DNN for one epoch with the ClassMix scheme.
    input:
      model: the model you want to train
      train_loader_with_labels: loader yielding (inputs, labels) batches
      train_loader_without_labels: loader yielding (inputs, ...) batches whose
        labels are ignored; at least 2 images per batch are used for mixing
      optim: the optimizer you use
      criterion: callable(output_labeled, labels, output_pseudo_labeled,
        mixed_pseudo_labels) returning the scalar loss
      metric: other criteria (reset / add / value interface)
      iteration_loss : boolean that allow you to print the loss
    output:
      epoch_loss: the mean loss over the epoch's steps
      metric.value(): the value of the other criteria

    One step is run per batch of the LONGER loader; the shorter loader is
    re-iterated as needed. Inputs are moved to the model's device.
    '''
    model.train()
    epoch_loss = 0.0
    metric.reset()

    device = next(model.parameters()).device
    n_labeled = len(train_loader_with_labels)
    n_unlabeled = len(train_loader_without_labels)
    num_steps = max(n_labeled, n_unlabeled) if (n_labeled and n_unlabeled) else 0

    # range() goes first so zip stops before pulling an extra batch.
    batches = zip(range(num_steps),
                  _repeat_batches(train_loader_with_labels),
                  _repeat_batches(train_loader_without_labels))

    for step, batch_data_with_labels, batch_data_without_labels in batches:
        # Get the inputs and labels
        input_with_labels = batch_data_with_labels[0].to(device)
        labels = batch_data_with_labels[1].to(device)
        input_without_labels = batch_data_without_labels[0].to(device)

        optim.zero_grad()

        # CLASSMIX PART (pseudo labels generation and mixing)
        mixed_input, mixed_pseudo_labels = _classmix(model, input_without_labels)

        # Forward (labeled / pseudo labeled)
        output_labeled = model(input_with_labels)
        output_pseudo_labeled = model(mixed_input)

        # Loss computation (combines labeled / pseudo labeled)
        loss = criterion(output_labeled, labels, output_pseudo_labeled, mixed_pseudo_labels)

        # Backpropagation
        loss.backward()
        optim.step()

        # Keep track of loss for current epoch
        epoch_loss += loss.item()

        # Keep track of the evaluation metric on labeled + pseudo-labeled data.
        # NOTE(review): pseudo labels are the model's own predictions, so that
        # part of the training mIoU measures self-consistency, not accuracy.
        output_concat = torch.cat((output_labeled, output_pseudo_labeled), dim=0)
        labels_concat = torch.cat((labels, mixed_pseudo_labels), dim=0)
        metric.add(output_concat.detach(), labels_concat.detach())

        if iteration_loss:
            print("[Step: %d] Iteration loss: %.4f" % (step, loss.item()))

    return epoch_loss / max(num_steps, 1), metric.value()

def test(model, test_loader, criterion, metric, iteration_loss=False):
    '''
    Validation script: evaluates the DNN on a labeled loader.
    input:
      model: the DNN you want to evaluate
      test_loader: the dataloader yielding (inputs, labels) batches
      criterion: callable(outputs, labels) returning the scalar loss
      metric: other criteria (reset / add / value interface); value()[0]
        is printed per class, paired with the names in `class_encoding`
      iteration_loss : boolean that allow you to print the loss
    output:
      epoch_loss: the mean loss over the batches (0.0 for an empty loader)
      metric.value(): the value of the other criteria

    Inputs are moved to the model's device.
    '''
    model.eval()
    epoch_loss = 0.0
    metric.reset()
    device = next(model.parameters()).device

    with torch.no_grad():
        for step, batch_data in enumerate(test_loader):
            # Get the inputs and labels
            inputs = batch_data[0].to(device)
            labels = batch_data[1].to(device)

            # Forward propagation
            outputs = model(inputs)

            # Loss computation
            loss = criterion(outputs, labels)

            # Keep track of loss for current epoch
            epoch_loss += loss.item()

            # Keep track of evaluation the metric
            metric.add(outputs, labels)

            if iteration_loss:
                print("[Step: %d] Iteration loss: %.4f" % (step, loss.item()))

    # Compute the metric once and reuse it for printing and returning.
    result = metric.value()
    for classe, res in zip(class_encoding, result[0]):
        print(f"[{classe}] : {res}")
    return epoch_loss / max(len(test_loader), 1), result

# Backbone

In [None]:
class double_conv(nn.Module):
    '''Two stacked (3x3 conv => BatchNorm => ReLU) stages.

    padding=1 keeps H and W unchanged; channels go in_ch -> out_ch. The layers
    live in ``self.conv`` (an ``nn.Sequential``) so state_dict keys such as
    ``conv.0.weight`` stay compatible with saved checkpoints.
    '''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    '''Input stem of the U-Net: a single ``double_conv`` (in_ch -> out_ch),
    spatial size unchanged.'''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    '''Encoder stage: 2x2 max-pool (halves H and W) followed by ``double_conv``.'''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = double_conv(in_ch, out_ch)
        # Kept as one Sequential named `mpconv` for checkpoint compatibility.
        self.mpconv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.mpconv(x)


class up(nn.Module):
    '''Decoder stage: upsample ``x1`` by 2, pad it to the spatial size of the
    skip tensor ``x2``, concatenate ``[x2, x1]`` along channels and apply
    ``double_conv(in_ch, out_ch)``.

    ``in_ch`` is the channel count AFTER concatenation. With ``bilinear=True``
    (default) upsampling is parameter-free bilinear interpolation; otherwise
    the learned 2x2 transposed conv ``self.up`` is used (it assumes ``x1`` has
    ``in_ch // 2`` channels). ``self.up`` is created either way, so state_dict
    keys do not depend on ``bilinear``.
    '''

    def __init__(self, in_ch, out_ch, bilinear=True):
        super().__init__()
        self.bilinear = bilinear
        half = in_ch // 2
        self.up = nn.ConvTranspose2d(half, half, 2, stride=2)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = (F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True)
              if self.bilinear else self.up(x1))

        # Pad (or crop, when negative) x1 so its H and W match x2.
        # F.pad takes [left, right, top, bottom] for the last two dims.
        pad_h = x2.size(2) - x1.size(2)
        pad_w = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [pad_w // 2, pad_w - pad_w // 2,
                        pad_h // 2, pad_h - pad_h // 2])

        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        return self.conv(torch.cat([x2, x1], dim=1))


class outconv(nn.Module):
    '''Final 1x1 convolution mapping ``in_ch`` feature maps to ``out_ch`` class scores.'''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    '''U-Net segmentation backbone.

    Four max-pool encoder stages (32 -> 64 -> 128 -> 256 -> 256 channels) and
    four bilinear-upsampling decoder stages with skip connections, followed by
    a 1x1 classifier. ``up`` pads to the skip tensor's size, so the output
    keeps the input's H and W.

    Keyword arguments:
    - classes (int): number of output classes. Default: 20.
    - in_channels (int): number of input channels. Default: 5 (the range-image
      feature channels produced by ``RangeKitti``).

    forward: (N, in_channels, H, W) -> (N, classes, H, W) logits.
    '''
    def __init__(self, classes=20, in_channels=5):
        super(UNet, self).__init__()
        self.inc = inconv(in_channels, 32)

        self.down1 = down(32, 64)
        self.down2 = down(64, 128)
        self.down3 = down(128, 256)
        self.down4 = down(256, 256)

        # in_ch of each `up` = skip channels + upsampled channels.
        self.up4 = up(512, 128)
        self.up3 = up(256, 64)
        self.up2 = up(128, 32)
        self.up1 = up(64, 32)

        self.outconv = outconv(32, classes)

    def forward(self, x):
        x0 = self.inc(x)

        # Encoder
        d1 = self.down1(x0)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        # Decoder with skip connections
        u4 = self.up4(d4, d3)
        u3 = self.up3(u4, d2)
        u2 = self.up2(u3, d1)
        u1 = self.up1(u2, x0)

        return self.outconv(u1)

# Training

In [None]:
# REPRODUCIBILITY
# Seed the RNGs used below (weight init, DataLoader shuffling, ClassMix class
# sampling) so that Restart & Run All reproduces the same run.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MODEL
model = UNet(classes=20)

# DATA
data_path = 'range_dataset' # YOUSSEF
# data_path = 'drive/My Drive/Semantic_segmentation_IA321/data/sequences' # MAXIME


train_set_with_labels = RangeKitti(data_path, mode='train_with_labels')
train_set_without_labels = RangeKitti(data_path, mode='train_without_labels')
val_set = RangeKitti(data_path, mode='test')


train_with_labels_loader = data.DataLoader(
        train_set_with_labels,
        batch_size=2,
        shuffle=True,
        num_workers=2)

# BATCH SIZE MUST BE 2 FOR NO LABELS PART (HERE WE MIX ONLY 2 IMAGES)
# drop_last guarantees every unlabeled batch really holds 2 images.
train_without_labels_loader = data.DataLoader(
        train_set_without_labels,
        batch_size=2,
        shuffle=True,
        drop_last=True,
        num_workers=2)

# Validation metrics do not depend on order, so there is no need to shuffle.
val_loader = data.DataLoader(
        val_set,
        batch_size=2,
        shuffle=False,
        num_workers=2)


model.to(device)

# here are the training parameters
learning_rate = 1e-3
weight_decay = 2e-4
lr_decay_epochs = 20
lr_decay = 0.1
nb_epochs = 50


# ClassMix loss for training, plain weighted cross-entropy for validation
train_criterion = class_mix_loss
test_criterion = nn.CrossEntropyLoss(weight=weights.to(device).float())

# We build the optimizer
optimizer = optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay)

# Learning rate decay scheduler
lr_updater = lr_scheduler.StepLR(optimizer, lr_decay_epochs,
                                     lr_decay)

# Evaluation metric: the "unlabeled" class is excluded from the IoU.
ignore_index = [class_encoding['unlabeled']]
metric = IoU(20, ignore_index=ignore_index)

# Start Training
best_miou = 0
train_loss_history_1 = []
val_loss_history_1 = []
train_miou_history_1 = []
val_miou_history_1 = []

for epoch in range(nb_epochs):
  print(">>>> [Epoch: {0:d}] Training".format(epoch))
  epoch_loss, (iou, miou) = train(model, train_with_labels_loader, train_without_labels_loader, optimizer, train_criterion, metric)
  lr_updater.step()
  print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".format(epoch, epoch_loss, miou))
  train_miou = miou
  train_loss = epoch_loss

  # Validate every 5 epochs and on the last one.
  if (epoch + 1) % 5 == 0 or epoch + 1 == nb_epochs:
    print(">>>> [Epoch: {0:d}] Validation".format(epoch))
    loss, (iou, miou) = test(model, val_loader, test_criterion, metric)
    print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".format(epoch, loss, miou))

    train_loss_history_1.append(train_loss)
    val_loss_history_1.append(loss)
    train_miou_history_1.append(train_miou)
    val_miou_history_1.append(miou)
    # Print per class IoU on last epoch or if best iou
    if epoch + 1 == nb_epochs or miou > best_miou:
      for key, class_iou in zip(class_encoding.keys(), iou):
        print("{0}: {1:.4f}".format(key, class_iou))
    # Save the model if it's the best thus far (checked once per validation,
    # not once per printed class).
    if miou > best_miou:
      print("\nBest model thus far. Saving...\n")
      best_miou = miou
      torch.save(model.state_dict(), "Unet_epoch{}.pt".format(epoch+1))


torch.save(model.state_dict(), "Unet_epoch{}_final.pt".format(nb_epochs))

print('train_loss_history_1', train_loss_history_1)
print('val_loss_history_1',val_loss_history_1)
print('train_miou_history_1',train_miou_history_1)
print('val_miou_history_1',val_miou_history_1)

>>>> [Epoch: 0] Training


KeyboardInterrupt: ignored