In [1]:
# Mount Google Drive into the Colab VM: later cells read the saved dataloaders and
# write checkpoints / plots under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Decoder (DeepLabV3+ head with ASPP)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ decoder head.

    ASPP is applied to the high-level feature map. The result is bilinearly resized to
    the spatial size of the (1x1-projected) low-level feature map. The two are then
    concatenated along channels, and a small conv classifier produces per-pixel class
    logits at the low-level feature resolution.

    Args:
        num_classes: number of output logit channels.
        in_channels: channels of the high-level map passed as ``out`` to ``forward``.
        low_level_channels: channels of the low-level map passed as ``low_features``.
        aspp_dilate: dilation rates of the ASPP atrous branches.
    """

    def __init__(self, num_classes, in_channels=768, low_level_channels=96,
                 aspp_dilate=(12, 24, 36)):
        super(DeepLabHeadV3Plus, self).__init__()
        low_level_out = 48  # channels after the low-level 1x1 projection
        aspp_out = 256      # must equal the channels ASPP produces (256)

        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, low_level_out, 1, bias=False),
            nn.BatchNorm2d(low_level_out),
            nn.ReLU(inplace=True),
        )

        # Tuple default avoids the shared-mutable-default pitfall; ASPP accepts any sequence.
        self.aspp = ASPP(in_channels, aspp_dilate)

        self.classifier = nn.Sequential(
            # Input is concat(low-level projection, ASPP output) along channels: 48 + 256 = 304.
            nn.Conv2d(low_level_out + aspp_out, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),  # extra regularisation added by the author on top of the head
            nn.Conv2d(256, num_classes, 1)
        )
        self._init_weight()

    def forward(self, low_features, out):
        """Return class logits at the spatial size of ``low_features``.

        Args:
            low_features: low-level feature map with ``low_level_channels`` channels.
            out: high-level feature map with ``in_channels`` channels.
        """
        low_level_feature = self.project(low_features)
        output_feature = self.aspp(out)
        output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear',
                                       align_corners=False)
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))

    def _init_weight(self):
        """Kaiming-normal init for every conv; unit weight / zero bias for norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class ASPPConv(nn.Sequential):
    """One atrous ASPP branch: dilated 3x3 conv -> BatchNorm -> ReLU.

    Padding equals the dilation, so the spatial size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        super().__init__(atrous_conv, norm, activation)


class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch: global average pool -> 1x1 conv -> BatchNorm -> ReLU,
    then bilinearly broadcast back to the input's spatial size."""

    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]
        norm = self[2]
        if norm.training and x.shape[0] == 1:
            # After global pooling each channel holds a single value per sample, so
            # BatchNorm cannot compute batch statistics for a batch of one (it raises
            # ValueError in training mode). Use the running statistics for this batch.
            norm.eval()
            try:
                pooled = super(ASPPPooling, self).forward(x)
            finally:
                norm.train()
        else:
            pooled = super(ASPPPooling, self).forward(x)
        return F.interpolate(pooled, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs parallel branches over the same input: a 1x1 conv, one 3x3 atrous conv per rate
    in ``atrous_rates``, and an image-level pooling branch. The branch outputs are
    concatenated along channels and projected back to ``out_channels`` (with dropout).

    Args:
        in_channels: channels of the input feature map.
        atrous_rates: dilation rates, one atrous branch each (any number of rates).
        out_channels: channels of every branch and of the final projection (default 256).
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        atrous_rates = tuple(atrous_rates)

        branches = [nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))]
        branches.extend(ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates)
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # 1x1 branch + one branch per rate + pooling branch, all concatenated.
        num_branches = len(atrous_rates) + 2
        self.project = nn.Sequential(
            nn.Conv2d(num_branches * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1), )

    def forward(self, x):
        res = torch.cat([branch(x) for branch in self.convs], dim=1)
        return self.project(res)


## Student model (Swin-T backbone + DeepLabV3+ head)

In [3]:
import torch.nn as nn
from transformers import OneFormerForUniversalSegmentation


class SwinDeepLabV3Plus(nn.Module):
    """Segmentation model: pretrained OneFormer Swin-Tiny encoder + DeepLabV3+ head.

    Args:
        num_classes: number of output classes of the DeepLabV3+ head.
    """

    def __init__(self, num_classes):
        super(SwinDeepLabV3Plus, self).__init__()

        # Swin Transformer backbone, taken from a pretrained OneFormer checkpoint;
        # only its pixel-level encoder is kept.
        oneformer_swin_t = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
        self.backbone = oneformer_swin_t.model.pixel_level_module.encoder

        # DeepLabV3+ head. It previously ignored `num_classes` and always built 133 outputs.
        self.classifier = DeepLabHeadV3Plus(num_classes=num_classes)

    def forward(self, x):
        """Return ``(logits, preds)``: per-pixel class logits and their argmax over classes."""
        features = self.backbone(x)['feature_maps']
        # First and fourth encoder stages; the head's defaults (96 / 768 channels)
        # presumably match Swin-T's first and last stages — TODO confirm.
        low_features = features[0]
        out_backbone = features[3]

        logits_output = self.classifier(low_features, out_backbone)

        # argmax over logits equals argmax over their softmax, so skip the softmax.
        preds = logits_output.argmax(dim=1)

        return logits_output, preds


## Prepare dataloader

In [4]:
# from utils.prepare_dataset import prepare_data

# prepare_data()

## Visualization function

In [5]:
# COCO (panoptic, 133 classes) class names in id order; enumerate() assigns ids 0..132.
id2label = dict(enumerate([
    # 0-9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    # 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    # 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    # 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    # 40-49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    # 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    # 60-69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 70-79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    # 80-89
    "banner", "blanket", "bridge", "cardboard", "counter",
    "curtain", "door-stuff", "floor-wood", "flower", "fruit",
    # 90-99
    "gravel", "house", "light", "mirror-stuff", "net",
    "pillow", "platform", "playingfield", "railroad", "river",
    # 100-109
    "road", "roof", "sand", "sea", "shelf",
    "snow", "stairs", "tent", "towel", "wall-brick",
    # 110-119
    "wall-stone", "wall-tile", "wall-wood", "water-other", "window-blind",
    "window-other", "tree-merged", "fence-merged", "ceiling-merged", "sky-other-merged",
    # 120-129
    "cabinet-merged", "table-merged", "floor-other-merged", "pavement-merged", "mountain-merged",
    "grass-merged", "dirt-merged", "paper-merged", "food-other-merged", "building-other-merged",
    # 130-132
    "rock-merged", "wall-other-merged", "rug-merged",
]))

In [6]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import cm


def visualize_segmentation(segmentation_tensor, plot_save_path, save_dir='/content/drive/MyDrive/',
                           label_map=None):
    """Colour a 2-D map of class ids, add a legend, save the figure and show it.

    Args:
        segmentation_tensor: 2-D tensor of integer class ids (indexed as H x W below).
        plot_save_path: file name of the saved figure, joined onto ``save_dir``.
        save_dir: directory the figure is written to (defaults to the Drive root used so far).
        label_map: id -> class-name mapping for the legend; defaults to ``id2label``.
            Ids missing from the mapping are labelled with their number instead of raising.
    """
    import os

    names = id2label if label_map is None else label_map

    class_indices = segmentation_tensor.long().cpu().numpy()
    # Sorted unique ids present in the map (same order torch.unique returns).
    label_ids = np.unique(class_indices)

    # One viridis colour per present id, spaced by the id's position in the sorted list.
    palette = [cm.viridis(index / len(label_ids)) for index in range(len(label_ids))]
    colors = np.asarray(palette, dtype=np.float32)
    # Vectorised lookup (replaces a per-pixel Python loop): every pixel id is in label_ids,
    # so searchsorted returns its exact position and hence its colour.
    segmented_image = colors[np.searchsorted(label_ids, class_indices)]

    fig, ax = plt.subplots()
    ax.imshow(segmented_image)
    ax.axis('off')
    ax.set_title('Segmentation Map')

    handles = [
        mpatches.Patch(color=palette[pos], label=names.get(int(label_id), str(int(label_id))))
        for pos, label_id in enumerate(label_ids)
    ]
    ax.legend(handles=handles, loc='upper left', bbox_to_anchor=(1, 1))

    # The legend sits outside the axes. bbox_inches='tight' grows the saved canvas to
    # include it, instead of clipping it on the right / bottom.
    fig.savefig(os.path.join(save_dir, plot_save_path), bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across training epochs


## Losses

In [7]:
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Optional
from torch.nn.modules.loss import _Loss

# from .functional import soft_dice_score  (originally imported; defined inline below instead)

# Carried over from the original module layout. `__all__` only affects
# `from <module> import *`, so it has no effect inside this notebook.
__all__ = ["DiceLoss"]


def soft_dice_score(output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0,
                    eps: float = 1e-7, dims=None) -> torch.Tensor:
    """Soft Dice coefficient ``(2 * |X.Y| + smooth) / (|X| + |Y| + smooth)``.

    Args:
        output: predicted (soft) values.
        target: ground-truth values; must have exactly the same size as ``output``.
        smooth: additive smoothing applied to numerator and denominator.
        eps: lower clamp on the denominator to avoid division by zero.
        dims: dimensions to sum over; ``None`` sums over every element (scalar result).

    Returns:
        Dice score(s) with the ``dims`` dimensions reduced.
    """
    assert output.size() == target.size()
    reduce_kwargs = {} if dims is None else {"dim": dims}
    overlap = torch.sum(output * target, **reduce_kwargs)
    total = torch.sum(output + target, **reduce_kwargs)
    numerator = 2.0 * overlap + smooth
    denominator = (total + smooth).clamp_min(eps)
    return numerator / denominator


class DiceLoss(_Loss):
    """Multiclass soft Dice loss for dense predictions.

    One Dice score is computed per class by summing over the batch and pixel axes.
    The loss is the mean over classes of ``1 - dice`` (or ``-log(dice)``).
    """

    def __init__(
            self,
            log_loss=False,
            from_logits=True,
            smooth: float = 1e-7,
            ignore_index=None,
            eps=1e-7,
    ):
        """

        :param log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice`
        :param from_logits: If True assumes input is raw logits (softmax is applied over dim 1)
        :param smooth: Added to numerator and denominator of the Dice score
        :param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
        :param eps: Small epsilon for numerical stability
        """
        super(DiceLoss, self).__init__()

        self.from_logits = from_logits
        self.smooth = smooth
        # Previously accepted but discarded; now actually used in forward().
        self.ignore_index = ignore_index
        self.eps = eps
        self.log_loss = log_loss

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """

        :param y_pred: NxCxHxW (logits if `from_logits`, else probabilities)
        :param y_true: NxHxW integer class indices
        :return: scalar
        """
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Log-Exp gives more stable result and does not cause vanishing gradient on extreme values 0 and 1
            y_pred = y_pred.log_softmax(dim=1).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)  # sum over batch and pixels -> one Dice score per class

        # reshape (unlike view) also accepts non-contiguous inputs
        y_true = y_true.reshape(bs, -1)  # N, H*W
        y_pred = y_pred.reshape(bs, num_classes, -1)  # N, C, H*W

        if self.ignore_index is not None:
            valid = y_true != self.ignore_index  # N, H*W
            # Zero ignored pixels in both tensors so they add nothing to intersection or
            # cardinality; their label is mapped to 0 first so one_hot stays in range.
            y_pred = y_pred * valid.unsqueeze(1)
            y_true = F.one_hot((y_true * valid).long(), num_classes)  # N, H*W, C
            y_true = y_true.permute(0, 2, 1) * valid.unsqueeze(1)  # N, C, H*W
        else:
            y_true = F.one_hot(y_true.long(), num_classes)  # N, H*W, C
            y_true = y_true.permute(0, 2, 1)  # N, C, H*W

        scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        return loss.mean()



def softmax_focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Softmax version of focal loss between target and output logits.

    Args:
        output: Tensor of shape [B, C, *] (Similar to nn.CrossEntropyLoss)
        target: Tensor of shape [B, *] of class indices (Similar to nn.CrossEntropyLoss)
        gamma: Focal loss power factor
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'.
            'none': no reduction, returns a tensor shaped like ``target``.
            'mean': mean over all elements.
            'sum': sum over all elements.
            'batchwise_mean': mean over all non-batch dimensions, one value per sample ([B]).
            Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
        eps: lower clamp on the normalisation factor when ``normalized`` is True.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """
    log_softmax = F.log_softmax(output, dim=1)

    loss = F.nll_loss(log_softmax, target, reduction="none")
    pt = torch.exp(-loss)  # probability assigned to the true class

    # compute the loss
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * loss

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss = loss / norm_factor

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "batchwise_mean":
        # Previously loss.sum(0) (a sum over the batch), contradicting the documented
        # per-sample mean; reshape also handles a 1-D [B] target.
        return loss.reshape(loss.shape[0], -1).mean(dim=1)
    if reduction == "none":
        return loss
    raise ValueError(f"Unsupported reduction: {reduction!r}")


class CrossEntropyFocalLoss(nn.Module):
    """Softmax focal loss as an ``nn.Module``.

    Thin wrapper around ``softmax_focal_loss_with_logits``: the focal term is computed
    from softmax probabilities rather than the sigmoid of the original paper. Targets are
    class indices with one dimension fewer than the inputs, as for ``nn.CrossEntropyLoss``.

    Args:
        gamma: focusing exponent applied to ``(1 - p_t)``.
        reduction: reduction mode, forwarded unchanged to the functional form.
        normalized: whether to compute the normalized focal loss.
        reduced_threshold: threshold for the reduced focal loss (``None`` disables it).
    """

    def __init__(
        self,
        gamma: float = 2.0,
        reduction: str = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
    ):
        super().__init__()
        self.gamma, self.reduction = gamma, reduction
        self.normalized, self.reduced_threshold = normalized, reduced_threshold

    def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
        """Compute the loss for ``inputs`` [B, C, H, W] logits and ``targets`` [B, H, W] labels."""
        options = dict(
            gamma=self.gamma,
            reduction=self.reduction,
            normalized=self.normalized,
            reduced_threshold=self.reduced_threshold,
        )
        return softmax_focal_loss_with_logits(inputs, targets, **options)

## Training

In [11]:
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from transformers import OneFormerProcessor, AutoModelForUniversalSegmentation


def teacher_forward(teacher, *, output_size=(128, 128), **inputs):
    """Run the OneFormer teacher and turn its query outputs into a dense semantic map.

    Each query's class distribution (with the trailing null class dropped) is weighted by
    its sigmoid mask probability and summed over queries. The result is resized to
    ``output_size``.

    Args:
        teacher: OneFormer-style model returning ``class_queries_logits``
            [B, Q, C+1] and ``masks_queries_logits`` [B, Q, H, W].
        output_size: (height, width) of the returned maps. Keyword-only; default (128, 128).
        **inputs: forwarded unchanged to ``teacher``.

    Returns:
        (segmentation_scores, semantic_segmentation):
        - segmentation_scores: [B, C, *output_size]. These are non-negative per-class
          scores (products of softmax and sigmoid outputs summed over queries), NOT logits.
        - semantic_segmentation: [B, *output_size] argmax class ids.
    """
    raw_out = teacher(**inputs)

    class_queries_logits = raw_out.class_queries_logits  # [batch_size, num_queries, num_classes+1]
    masks_queries_logits = raw_out.masks_queries_logits  # [batch_size, num_queries, height, width]

    # Remove the null class `[..., :-1]`
    masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
    masks_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]

    # Per-class scores of shape (batch_size, num_classes, height, width)
    segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
    segmentation_logits = F.interpolate(segmentation_logits, size=output_size,
                                        mode='bilinear', align_corners=False)

    # softmax is monotonic, so argmax over the raw scores gives the same labels.
    semantic_segmentation = segmentation_logits.argmax(dim=1)
    return segmentation_logits, semantic_segmentation


def _teacher_inputs(processor, images, device):
    """Prepare one image batch for the OneFormer teacher (semantic task) on `device`."""
    # do_rescale=False: presumably the dataloaders already yield [0, 1] float images — TODO confirm.
    inputs = processor(images=images, task_inputs=["semantic"], return_tensors="pt",
                       do_rescale=False).to(device)
    # A single task prompt was given for the whole batch; tile it to one row per image.
    inputs["task_inputs"] = inputs["task_inputs"].repeat(images.shape[0], 1)
    return inputs


def _distillation_terms(student_logits, teacher_scores, pseudo_labels, temperature, ce_loss, dice_loss):
    """Return the unweighted (KL, CE, Dice) loss terms for one batch.

    NOTE(review): `teacher_scores` come from teacher_forward and are non-negative
    query-weighted scores, not logits. softmax(scores / T) over them may be much flatter
    than intended; consider taking a log or normalising first. TODO confirm intent.
    """
    num_classes = student_logits.shape[1]

    # Temperature-softened distributions, flattened to [batch * height * width, classes].
    soft_targets = F.softmax(teacher_scores / temperature, dim=1)
    soft_prob = F.log_softmax(student_logits / temperature, dim=1)
    soft_targets = soft_targets.permute(0, 2, 3, 1).reshape(-1, num_classes)
    soft_prob = soft_prob.permute(0, 2, 3, 1).reshape(-1, num_classes)

    # Soft-target loss scaled by T**2 ["Distilling the knowledge in a neural network"].
    kl = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)
    # Hard-label losses against the teacher's argmax pseudo-labels.
    ce = ce_loss(student_logits, pseudo_labels)
    dice = dice_loss(student_logits, pseudo_labels)
    return kl, ce, dice


def train(stud_id, path_to_save_model=None):
    """Distil the OneFormer Swin-L teacher into the SwinDeepLabV3Plus student.

    Loads pickled train/validation dataloaders for `stud_id` from Drive and freezes the
    student backbone. It then trains on a weighted sum of a temperature-scaled KL term and
    CE + Dice terms against teacher pseudo-labels. Every `save_every` epochs it saves a
    checkpoint, a loss plot and example segmentations from the last validation batch.

    Args:
        stud_id: suffix selecting `Train_dl_{stud_id}.pt` / `Validation_dl_{stud_id}.pt`.
        path_to_save_model: prefix (string-concatenated, so include the trailing slash)
            for checkpoint files; None disables checkpointing.

    Returns:
        dict of per-epoch lists: 'train', 'val' (total losses) and 'kl', 'ce', 'dice'
        (weighted training contributions).
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Device: ', device)
    torch.cuda.empty_cache()

    num_classes = 133
    # Training has always fed the student 512x512 inputs so that its output lines up with
    # teacher_forward's 128x128 maps; validation must use the same size.
    student_input_size = (512, 512)
    save_every = 10

    processor_teacher = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
    teacher = AutoModelForUniversalSegmentation.from_pretrained("shi-labs/oneformer_coco_swin_large").to(device)
    teacher.eval()  # the teacher is only used for inference

    student = SwinDeepLabV3Plus(num_classes=num_classes).to(device)

    # Freeze the backbone: only the DeepLabV3+ head is trained.
    for param in student.backbone.parameters():
        param.requires_grad = False

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    # weights_only=False is required on recent torch to unpickle a full DataLoader.
    train_dl = torch.load(f'/content/drive/MyDrive/data_loaders/Train_dl_{stud_id}.pt', weights_only=False)
    val_dl = torch.load(f'/content/drive/MyDrive/data_loaders/Validation_dl_{stud_id}.pt', weights_only=False)

    # Hyper-params settings
    learning_rate = 1e-03  # learning rate
    milestones = [5, 10, 15]  # the epochs after which the learning rate is adjusted by gamma
    gamma = 0.1  # multiplicative LR decay applied at each milestone
    weight_decay = 1e-05  # weight decay (L2 penalty)
    epochs = 20
    T = 2  # distillation temperature

    trainable_params = [p for p in student.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    use_scheduler = True  # use MultiStepLR scheduler
    if use_scheduler:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    # Loss weights (sum to 1).
    kl_loss_weight = 0.8
    ce_loss_weight = 0.1
    dice_loss_weight = 0.1

    ce_loss = nn.CrossEntropyLoss()
    dice_loss = DiceLoss()

    history = {'train': [], 'val': [], 'kl': [], 'ce': [], 'dice': []}

    for epoch in range(epochs):
        # ---------------- training ----------------
        student.train()
        running = {'total': 0.0, 'kl': 0.0, 'ce': 0.0, 'dice': 0.0}
        n = 0
        for images, _ in train_dl:
            batch_size = images.shape[0]
            n += batch_size

            semantic_inputs = _teacher_inputs(processor_teacher, images, device)

            optimizer.zero_grad()

            # Teacher forward without autograd: its weights are never updated.
            with torch.no_grad():
                teacher_logits, pseudo_labels = teacher_forward(teacher, **semantic_inputs)

            student_input = F.interpolate(semantic_inputs["pixel_values"], size=student_input_size,
                                          mode='bilinear', align_corners=False)
            student_logits, preds = student(student_input)

            kl_res, ce_res, dice_res = _distillation_terms(student_logits, teacher_logits, pseudo_labels,
                                                           T, ce_loss, dice_loss)
            loss = kl_loss_weight * kl_res + ce_loss_weight * ce_res + dice_loss_weight * dice_res

            loss.backward()
            optimizer.step()

            # .item() detaches: summing tensors would keep every batch's autograd graph alive.
            running['total'] += loss.item() * batch_size
            running['kl'] += kl_loss_weight * kl_res.item() * batch_size
            running['ce'] += ce_loss_weight * ce_res.item() * batch_size
            running['dice'] += dice_loss_weight * dice_res.item() * batch_size

        n = max(n, 1)
        history['train'].append(running['total'] / n)
        history['kl'].append(running['kl'] / n)
        history['ce'].append(running['ce'] / n)
        history['dice'].append(running['dice'] / n)

        # ---------------- validation ----------------
        student.eval()
        val_running = 0.0
        val_n = 0
        with torch.no_grad():
            for images, _ in val_dl:
                batch_size = images.shape[0]
                val_n += batch_size

                semantic_inputs = _teacher_inputs(processor_teacher, images, device)
                teacher_logits, pseudo_labels = teacher_forward(teacher, **semantic_inputs)

                # Same resize as in training so student and teacher maps have equal size.
                student_input = F.interpolate(semantic_inputs["pixel_values"], size=student_input_size,
                                              mode='bilinear', align_corners=False)
                student_logits, preds = student(student_input)

                kl_res, ce_res, dice_res = _distillation_terms(student_logits, teacher_logits, pseudo_labels,
                                                               T, ce_loss, dice_loss)
                loss = kl_loss_weight * kl_res + ce_loss_weight * ce_res + dice_loss_weight * dice_res
                val_running += loss.item() * batch_size

        history['val'].append(val_running / max(val_n, 1))
        print(f"epoch {epoch + 1}/{epochs} - train loss: {history['train'][-1]:.4f}"
              f" - val loss: {history['val'][-1]:.4f}")

        if use_scheduler:
            scheduler.step()

        # ---------------- save and plot ----------------
        if (epoch + 1) % save_every == 0:

            if path_to_save_model is not None:
                checkpoint_path = path_to_save_model + f'student_ckpt_epoch_{epoch + 1}.pth'
                torch.save(student.state_dict(), checkpoint_path)

            epochs_axis = range(1, epoch + 2)
            fig, ax = plt.subplots(figsize=(10, 6))
            curves = (('train', 'Training Loss'), ('val', 'Validation Loss'),
                      ('kl', 'KL Loss'), ('ce', 'CE Loss'), ('dice', 'Dice Loss'))
            for key, label in curves:
                ax.plot(epochs_axis, history[key], label=label, marker='o')
            ax.set_xlabel('Epochs')
            ax.set_ylabel('Loss')
            ax.set_title('Training and Validation Losses')
            ax.legend()
            ax.grid(True)

            plot_path = f'epoch_{epoch}_lr_{learning_rate}_temp_{T}.png'
            fig.savefig('/content/drive/MyDrive/' + plot_path)
            plt.show()
            plt.close(fig)

            # Examples from the last batch seen (it may hold fewer than two images).
            n_examples = min(2, pseudo_labels.shape[0])
            for k in range(n_examples):
                visualize_segmentation(pseudo_labels[k], plot_save_path=f'pseudo_labels_{k}_' + plot_path)
            for k in range(n_examples):
                visualize_segmentation(preds[k], plot_save_path=f'pred_{k}_' + plot_path)

    return history


In [None]:
# Run the full distillation for student split 1 and report wall-clock duration.
print("### Starting training... ###")
t0 = time.perf_counter()  # monotonic clock, suitable for measuring durations
train(stud_id=1, path_to_save_model='/content/drive/MyDrive/')
t1 = time.perf_counter()
print("training/validation time: {0:.2f}s".format(t1 - t0))
print("### Training finished ###")