In [None]:
import copy
import math
import os
import random
from collections import OrderedDict

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from pycocotools.coco import COCO
from torchvision.transforms import v2
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(dev)

import sys
# Last command-line argument of the running process, echoed for inspection.
ARGNAME = sys.argv[-1]
print(ARGNAME)
# Switch matplotlib to the non-interactive 'Agg' backend unless that argument
# starts with 'C:\'.
# NOTE(review): presumably a heuristic for "not running on the local Windows
# machine" (e.g. a headless server) -- confirm; it would also trigger for a
# Windows path on any drive other than C:.
if not ARGNAME.startswith('C:\\'):
    import matplotlib
    matplotlib.use('Agg')

In [None]:
class ImageGenerator(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, mask)`` pairs from a COCO export.

    All annotations of an image are rasterised with ``COCO.annToMask`` and
    OR-ed into a single mask, so every annotated category ends up in the same
    foreground channel.
    """
    def __init__(self, imgdir, cocofilename=r'_annotations.coco.json', size=None):
        """
          Args:
            imgdir: Directory containing the images and the annotation file.
            cocofilename: Name of the COCO annotation file inside ``imgdir``.
            size: Optional target size forwarded to ``v2.Resize``; ``None``
              (the default) returns images and masks at native resolution.
        """
        self.imgdir = imgdir
        # os.path.join instead of a hard-coded '\\' so the paths also resolve
        # on non-Windows hosts.
        self.coco = COCO(os.path.join(self.imgdir, cocofilename))
        # Image ids in annotation-file order; dataset positions index this list.
        self.list = list(self.coco.imgs.keys())
        self.resizer = None
        if size is not None:
            self.resizer = torchvision.transforms.v2.Resize(size,torchvision.transforms.InterpolationMode.BICUBIC,antialias=False)
    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.list)
    def __getitem__(self, idx):
        """Load the ``idx``-th image and its merged annotation mask.

        Returns:
            ``(image, mask)``: the RGB image read by ``torchvision.io.read_image``
            and an integer mask of shape ``(height, width)`` taken from the
            image record. Both are passed through the resizer when one is set.
        """
        # COCO image ids are arbitrary keys, not positions: translate the
        # dataset index into an id first (the ids need not be 0..N-1).
        img = self.coco.imgs[self.list[idx]]
        filename = os.path.join(self.imgdir, img['file_name'])
        image = torchvision.io.read_image(filename,torchvision.io.ImageReadMode.RGB)
        cat_ids = self.coco.getCatIds()
        anns_ids = self.coco.getAnnIds(imgIds=img['id'], catIds=cat_ids, iscrowd=None)
        mask = np.zeros((img['height'], img['width']), dtype=int)
        for ann in self.coco.loadAnns(anns_ids):
            mask |= self.coco.annToMask(ann)
        mask = torch.from_numpy(mask)
        if self.resizer is not None:
            # Wrapping in tv_tensors lets the v2 transform treat the image and
            # the mask according to their respective types.
            image, mask = self.resizer(torchvision.tv_tensors.Image(image), torchvision.tv_tensors.Mask(mask))
        return image, mask
def display(display_list,save=None):
    """Plot one row per sample: input image, true mask and, if present, the prediction.

    Args:
        display_list: Sequence of dicts with the keys 'image' (a channels-first
            tensor) and 'mask', plus an optional 'pred'.
        save: Optional path; when given, the figure is written there before it
            is shown.
    """
    rows = len(display_list)
    fig = plt.figure(figsize=(3*5, rows*5))
    for row, sample in enumerate(display_list):
        plt.subplot(rows, 3, row*3+1)
        plt.title('Input Image')
        plt.imshow(sample['image'].permute(1,2,0)) #pytorch loves channels first but matplotlib loves channels last
        plt.axis('off')
        plt.subplot(rows, 3, row*3+2)
        plt.title('True Mask')
        plt.imshow(sample['mask'])
        plt.axis('off')
        if 'pred' in sample:
            plt.subplot(rows, 3, row*3+3)
            plt.title('Predicted Mask')
            plt.imshow(sample['pred'])
            plt.axis('off')
    if save is not None:
        plt.savefig(save) #save before show
    plt.show()
    # Under a non-interactive backend show() does not dispose of the figure;
    # close it explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

# Networks

In [None]:
def ConvBlock(num,channels,filters,addpool=False,reversed=False):
    """Build a Sequential of `num` 3x3 Conv -> BatchNorm -> ReLU units.

    Args:
        num: Number of conv units.
        channels: Input channels; None makes the first conv lazy
            (only allowed when `reversed` is False).
        filters: Output channels of the block.
        addpool: Prepend a 2x2, stride-2 max-pool.
        reversed: If False the width changes in the first unit
            (channels -> filters, then filters -> filters); if True the width is
            held at `channels` and only the last unit maps to `filters`.
    """
    def conv_bn_relu(c_in, c_out):
        # Bias is pointless for BatchNorm
        if c_in is None:
            conv = torch.nn.LazyConv2d(c_out,kernel_size=(3,3),padding='same',bias=False)
        else:
            conv = torch.nn.Conv2d(c_in,c_out,kernel_size=(3,3),padding='same',bias=False)
        return [conv, torch.nn.BatchNorm2d(c_out), torch.nn.ReLU(inplace=True)]

    layers = []
    if addpool:
        layers.append(torch.nn.MaxPool2d(2, stride=2))
    if reversed:
        if channels is None:
            raise NotImplementedError("nah son") #Must have a channel
        widths = [(channels, channels)] * (num-1) + [(channels, filters)]
    elif num > 0:
        widths = [(channels, filters)] + [(filters, filters)] * (num-1)
    else:
        widths = []
    for c_in, c_out in widths:
        layers += conv_bn_relu(c_in, c_out)
    return torch.nn.Sequential(*layers)
class BottleResBlock(torch.nn.Module):
    """Bottleneck residual unit.

    Main path: 1x1 reduce -> 3x3 (optionally strided / dilated) -> 1x1 expand,
    each followed by BatchNorm. The shortcut is the identity when the shape is
    unchanged, otherwise a strided 1x1 conv + BatchNorm. No activation is
    applied after the sum.
    """
    def __init__(self,channels,bottlechan,filters,stride=1,dilation=1):
        super().__init__()
        main_path = [
            torch.nn.Conv2d(channels,bottlechan,kernel_size=(1,1),padding='valid',bias=False), #Bias is pointless for BatchNorm
            torch.nn.BatchNorm2d(bottlechan),
            torch.nn.ReLU(),
            # padding == dilation keeps the spatial size for a 3x3 kernel at stride 1
            torch.nn.Conv2d(bottlechan,bottlechan,kernel_size=(3,3),padding=dilation,stride=stride,dilation=dilation,bias=False),
            torch.nn.BatchNorm2d(bottlechan),
            torch.nn.ReLU(),
            torch.nn.Conv2d(bottlechan,filters,kernel_size=(1,1),padding='valid',bias=False),
            torch.nn.BatchNorm2d(filters),
        ]
        self.bottleneck = torch.nn.Sequential(*main_path)
        if channels == filters and stride == 1:
            self.skip = torch.nn.Identity()
        else:
            # Projection shortcut: match both channel count and stride.
            self.skip = torch.nn.Sequential(
                torch.nn.Conv2d(channels,filters,kernel_size=(1,1),padding='valid',stride=stride,bias=False),
                torch.nn.BatchNorm2d(filters),
            )
    def forward(self,x):
        shortcut = self.skip(x)
        residual = self.bottleneck(x)
        return shortcut + residual
class InvertedBottleResBlock(torch.nn.Module):
    """Inverted residual unit: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    The input is added back only when the block keeps both resolution
    (stride 1) and channel count; there is no activation after the projection.
    """
    def __init__(self,channels,expansion,filters,stride=1):
        super().__init__()
        hidden = channels*expansion
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(channels,hidden,kernel_size=(1,1),padding='valid',bias=False), #Bias is pointless for BatchNorm
            torch.nn.BatchNorm2d(hidden),
            torch.nn.ReLU(),
            # groups == channels makes this a depthwise convolution
            torch.nn.Conv2d(hidden,hidden,kernel_size=(3,3),padding=1,stride=stride,groups=hidden,bias=False),
            torch.nn.BatchNorm2d(hidden),
            torch.nn.ReLU(),
            torch.nn.Conv2d(hidden,filters,kernel_size=(1,1),padding='valid',bias=False),
            torch.nn.BatchNorm2d(filters) #No ReLU here
        )
        # Plain bool flag (not a module): whether the residual add is applied.
        self.skip = stride == 1 and channels == filters
    def forward(self,x):
        y = self.bottleneck(x)
        if self.skip:
            y = torch.add(y,x)
        return y
class PPM(torch.nn.Module):
    """Pyramid pooling module.

    Each entry of `pools` creates a branch: average-pool by that factor
    (skipped for factor 1), 1x1 conv to `in_channels//4`, BatchNorm, ReLU.
    Branch outputs are resized back to the input resolution and concatenated
    with the input along the channel axis, giving
    `in_channels + len(pools) * (in_channels//4)` output channels.

    Args:
        in_channels: Channels of the incoming feature map.
        pools: Pooling factors, one branch per entry.
    """
    def __init__(self,in_channels,pools=[1,2,4,8]):
        super().__init__()
        def poollayer(size):
            block = torch.nn.Sequential()
            if size != 1:
                block.append(torch.nn.AvgPool2d(size, stride=size))
            block.append(torch.nn.Conv2d(in_channels,in_channels//4,kernel_size=(1,1),bias=False))
            block.append(torch.nn.BatchNorm2d(in_channels//4))
            block.append(torch.nn.ReLU())
            return block
        self.pools = torch.nn.ModuleList([poollayer(x) for x in pools])
    def forward(self, x):
        target_size = x.shape[2:]
        features = []
        for branch in self.pools:
            pooled = branch(x)
            if pooled.shape[2:] != target_size:
                # Resize to the exact input size rather than by a fixed scale
                # factor, so inputs whose height/width are not multiples of
                # the pool size still concatenate cleanly.
                pooled = torch.nn.functional.interpolate(pooled, size=target_size, mode='bilinear', align_corners=False)
            features.append(pooled)
        features.append(x)
        return torch.cat(features,1)
class ASPP(torch.nn.Module):
    """Atrous spatial pyramid pooling.

    Parallel branches over the same input: one 3x3 dilated conv per entry of
    `dilations`, a 1x1 conv, and a global-average-pooled 1x1 conv that is
    resized back to the input resolution. The concatenation is fused by a
    final 1x1 conv to `out_channel_per_branch` channels.
    """
    def __init__(self,in_channels,out_channel_per_branch,dilations=[6,12,18]):
        super().__init__()
        width = out_channel_per_branch
        def conv_bn_relu(c_in, kernel, **conv_kwargs):
            return [
                torch.nn.Conv2d(c_in,width,kernel_size=kernel,bias=False,**conv_kwargs),
                torch.nn.BatchNorm2d(width),
                torch.nn.ReLU(),
            ]
        self.conv1 = torch.nn.Sequential(*conv_bn_relu(in_channels,(1,1)))
        # padding == dilation rate keeps the spatial size for a 3x3 kernel
        self.dilations = torch.nn.ModuleList([
            torch.nn.Sequential(*conv_bn_relu(in_channels,(3,3),padding=rate,dilation=rate))
            for rate in dilations
        ])
        self.globalpool = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(1),
            *conv_bn_relu(in_channels,(1,1))
        )
        # inputs of the fusion conv: every dilated branch + 1x1 branch + global branch
        self.finalconv = torch.nn.Sequential(*conv_bn_relu(width*(2+len(self.dilations)),(1,1)))
    def forward(self, x):
        spatial = x.size()[2:]
        branches = [branch(x) for branch in self.dilations]
        branches.append(self.conv1(x))
        pooled = self.globalpool(x)
        branches.append(torch.nn.functional.interpolate(pooled, spatial, mode='bilinear', align_corners=True))
        return self.finalconv(torch.cat(branches,1))
class ResNet(torch.nn.Module):
    """Bottleneck ResNet classifier.

    A strided 7x7 stem is followed by four stages of BottleResBlock units
    (`layers` gives the unit count per stage), global average pooling and a
    linear layer producing `classes` outputs.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3]):
        super().__init__()
        def make_stage(depth,channels,bottlechan,filters,stride=1,addpool=False):
            units = []
            if addpool:
                units.append(torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            for index in range(depth):
                # Only the first unit of a stage may be strided; from the
                # second unit on the input width equals `filters`.
                unit_stride = stride if index == 0 else 1
                units.append(BottleResBlock(channels,bottlechan,filters,unit_stride))
                channels = filters
            return torch.nn.Sequential(*units)
        stem = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,64,kernel_size=(7,7),stride=2,padding=3,bias=False),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        # (input width, bottleneck width, output width, stride, prepend max-pool)
        stage_specs = [
            (64,64,256,1,True),
            (256,128,512,2,False),
            (512,256,1024,2,False),
            (1024,512,2048,2,False),
        ]
        stages = [
            make_stage(layers[position],c_in,c_mid,c_out,stride=stride,addpool=pool)
            for position,(c_in,c_mid,c_out,stride,pool) in enumerate(stage_specs)
        ]
        self.blocks = torch.nn.ModuleList([stem] + stages)
        self.pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = torch.nn.Flatten(start_dim=1, end_dim=-1)
        self.classifier = torch.nn.Linear(2048,classes)
    def forward(self, x):
        for stage in self.blocks:
            x = stage(x)
        pooled = self.flatten(self.pooling(x))
        return self.classifier(pooled)
class UNet(torch.nn.Module):
    """U-Net: five encoder stages, four transposed-conv upsamplers with skip
    concatenation, and a 1x1 classifier.

    Args:
        in_channels: Channels of the input image (used by the built-in encoder).
        classes: Number of output channels of the classifier.
        convblocks: Optional torch.nn.ModuleList whose first five entries are
            used as the encoder stages. In that case the first upsampler and
            the first conv of every decoder block are lazy, so their input
            widths are inferred on the first forward pass.
        channels: Feature width per resolution level, shallowest first.
        upsample_output: Bilinear scale factor applied to the logits
            (1 disables it).
    """
    def __init__(self,in_channels,classes,convblocks=None,channels=[64,128,256,512,1024],upsample_output=1):
        super().__init__()
        # Copy so the instance never aliases the shared default list.
        self.channels = list(channels)
        custom_encoder = isinstance(convblocks, torch.nn.ModuleList)
        if custom_encoder:
            self.encoderblocks = torch.nn.ModuleList([convblocks[level] for level in range(5)])
        else:
            self.encoderblocks = torch.nn.ModuleList(
                [UNet._block(in_channels,self.channels[0])] +
                [UNet._block(self.channels[level-1],self.channels[level],True) for level in range(1,5)]
            )
        # With a custom encoder the bottleneck width is unknown up front.
        bottleneck_width = None if custom_encoder else self.channels[4]
        self.upsamplers = torch.nn.ModuleList([
            UNet._up_block(bottleneck_width,self.channels[3]),
            UNet._up_block(self.channels[3],self.channels[2]),
            UNet._up_block(self.channels[2],self.channels[1]),
            UNet._up_block(self.channels[1],self.channels[0]),
        ])
        # A decoder block sees the skip tensor concatenated with the upsampled
        # tensor. For the built-in encoder both have channels[level] channels,
        # hence 2*channels[level]; this equals channels[level+1] for the
        # default widths but also works when widths do not double per level.
        self.decoderblocks = torch.nn.ModuleList([
            UNet._block(None if custom_encoder else 2*self.channels[level],self.channels[level])
            for level in (3,2,1,0)
        ])
        self.classifier = torch.nn.Conv2d(self.channels[0],classes,kernel_size=(1,1))
        self.upsampler = torch.nn.Upsample(scale_factor=upsample_output, mode='bilinear') if upsample_output != 1 else torch.nn.Identity()
    @staticmethod
    def _block(channels,filters,pool=False):
        """Two 3x3 Conv -> BatchNorm -> ReLU units, optionally preceded by a
        2x2 max-pool. `channels=None` makes the first conv lazy."""
        block = torch.nn.Sequential()
        if pool:
            block.append(torch.nn.MaxPool2d(2, stride=2))
        if channels is None:
            block.append(torch.nn.LazyConv2d(filters,kernel_size=(3,3),padding=(1,1),bias=False)) #Bias is pointless for BatchNorm
        else:
            block.append(torch.nn.Conv2d(channels,filters,kernel_size=(3,3),padding=(1,1),bias=False)) #Bias is pointless for BatchNorm
        block.append(torch.nn.BatchNorm2d(filters))
        block.append(torch.nn.ReLU(inplace=True))
        block.append(torch.nn.Conv2d(filters,filters,kernel_size=(3,3),padding=(1,1),bias=False)) #Bias is pointless for BatchNorm
        block.append(torch.nn.BatchNorm2d(filters))
        block.append(torch.nn.ReLU(inplace=True))
        return block
    @staticmethod
    def _up_block(channels,filters):
        """2x2, stride-2 transposed conv followed by ReLU (doubles height and
        width). `channels=None` makes the transposed conv lazy."""
        if channels is None:
            deconv = torch.nn.LazyConvTranspose2d(filters,kernel_size=(2,2),stride=(2,2))
        else:
            deconv = torch.nn.ConvTranspose2d(channels,filters,kernel_size=(2,2),stride=(2,2))
        return torch.nn.Sequential(deconv, torch.nn.ReLU())
    def forward(self, x):
        # Encoder: keep every stage output; the last one is the bottleneck.
        skips = []
        for encoder in self.encoderblocks:
            x = encoder(x)
            skips.append(x)
        x = skips.pop()
        # Decoder: upsample, concatenate with the matching skip (skip first),
        # then refine.
        for upsample, decoder in zip(self.upsamplers, self.decoderblocks):
            skip = skips.pop()
            x = decoder(torch.cat((skip, upsample(x)),1))
        y = self.classifier(x)
        y = self.upsampler(y)
        return y
class UNetPPM(UNet):
    """UNet whose last decoder block is a pyramid pooling module followed by a
    ConvBlock.

    Args:
        in_channels, classes, convblocks, upsample_output: forwarded to UNet.
        encoderblocks: Kept for call compatibility; not used by this class.
        decoderblocks: Only entry 3 is used, as the number of conv units of the
            replacement decoder block.
        pools: Pooling factors of the PPM.
    """
    def __init__(self,in_channels,classes,convblocks=None,encoderblocks=[2,2,2,2,2],decoderblocks=[2,2,2,2],upsample_output=1,pools=[1,2,4,8]):
        # UNet.__init__ accepts (in_channels, classes, convblocks, channels,
        # upsample_output). Forwarding the two per-block count lists
        # positionally supplied one argument too many (TypeError) and would
        # have bound `encoderblocks` to `channels`, so only the arguments UNet
        # understands are passed on and the default widths are used.
        super().__init__(in_channels,classes,convblocks,upsample_output=upsample_output)
        # The last decoder block receives channels[1] channels with the default
        # widths; PPM appends len(pools) branches of channels[1]//4 each.
        self.decoderblocks[3] = torch.nn.Sequential(
            PPM(self.channels[1],pools),
            ConvBlock(decoderblocks[3],(self.channels[1]//4)*len(pools)+self.channels[1],self.channels[0]),
        )
class UNetASPP(UNet):
    """UNet whose last decoder block is a ConvBlock followed by an ASPP module.

    Args:
        in_channels, classes, convblocks, upsample_output: forwarded to UNet.
        encoderblocks: Kept for call compatibility; not used by this class.
        decoderblocks: Only entry 3 is used, as the number of conv units of the
            replacement decoder block.
        dilations: Dilation rates of the ASPP branches.
    """
    def __init__(self,in_channels,classes,convblocks=None,encoderblocks=[2,2,2,2,2],decoderblocks=[2,2,2,2],upsample_output=1,dilations=[6,12,18]):
        # UNet.__init__ accepts (in_channels, classes, convblocks, channels,
        # upsample_output). Forwarding the two per-block count lists
        # positionally supplied one argument too many (TypeError) and would
        # have bound `encoderblocks` to `channels`, so only the arguments UNet
        # understands are passed on and the default widths are used.
        super().__init__(in_channels,classes,convblocks,upsample_output=upsample_output)
        self.decoderblocks[3] = torch.nn.Sequential(
            ConvBlock(decoderblocks[3],self.channels[1],self.channels[0]),
            ASPP(self.channels[0],self.channels[0],dilations),
        )
class SegNet(torch.nn.Module):
    """SegNet-style encoder/decoder built from VGG16-BN feature slices.

    The encoder is five conv stages, each followed by a 2x2 max-pool that
    records its argmax indices; the decoder mirrors it with max-unpooling
    driven by those indices.

    Args:
        in_channels: Accepted for interface parity with the other networks;
            the visible code does not use it.
        classes: Output channels of the final 1x1 classifier.
        convblocks: Optional torch.nn.ModuleList whose first five entries are
            used as encoder stages instead of freshly created VGG16-BN slices.
        ptdec: If True the decoder convs taken from VGG16-BN start from the
            IMAGENET1K_V1 weights; otherwise they are randomly initialised.
    """
    def __init__(self,in_channels,classes,convblocks=None,ptdec=False):
        super().__init__()
        if isinstance(convblocks, torch.nn.ModuleList):
            self.encoderblocks = torch.nn.ModuleList([convblocks[stage] for stage in range(5)])
        else:
            # weights must be the None object; the string 'None' is not a
            # valid weights identifier.
            backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights=None)
            # The slices skip the max-pool layers of `features`; pooling is
            # done by self.downsamplers so the indices can be kept.
            self.encoderblocks = torch.nn.ModuleList([
                SegNet._seq(backbone.features[0:6]),
                SegNet._seq(backbone.features[7:13]),
                SegNet._seq(backbone.features[14:23]),
                SegNet._seq(backbone.features[24:33]),
                SegNet._seq(backbone.features[34:43]),
            ])
            del backbone
        self.downsamplers = torch.nn.ModuleList([
            torch.nn.MaxPool2d(2, stride=2,return_indices=True) for _ in range(5)
        ])
        self.upsamplers = torch.nn.ModuleList([
            torch.nn.MaxUnpool2d(kernel_size=2, stride=2) for _ in range(5)
        ])
        decoder_weights = 'VGG16_BN_Weights.IMAGENET1K_V1' if ptdec else None
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights=decoder_weights)
        # Each decoder stage reuses the trailing convs of the matching VGG
        # stage (the full-stage slice is noted on the right); the last conv of
        # a stage is replaced below by one that narrows the width.
        self.decoderblocks = torch.nn.ModuleList([
            SegNet._seq(backbone.features[34:43]),
            SegNet._seq(backbone.features[27:33]), #24:33
            SegNet._seq(backbone.features[17:23]), #14:23
            SegNet._seq(backbone.features[10:13]), #7:13
            SegNet._seq(backbone.features[3:6]), #0:6
        ])
        for stage, c_in, c_out in ((1,512,256),(2,256,128),(3,128,64),(4,64,64)):
            self.decoderblocks[stage].append(torch.nn.Conv2d(c_in,c_out,kernel_size=(3,3),padding=(1,1),bias=False)) #Bias is pointless for BatchNorm
            self.decoderblocks[stage].append(torch.nn.BatchNorm2d(c_out))
            self.decoderblocks[stage].append(torch.nn.ReLU(inplace=True))
        self.decoderblocks[4].append(torch.nn.Conv2d(64,classes,kernel_size=(1,1))) #Classifier
        del backbone
    @staticmethod
    def _seq(blocks):
        """Repack an iterable of modules into a fresh Sequential numbered from 0."""
        return torch.nn.Sequential(*blocks)
    def forward(self, x):
        # Encoder: remember the argmax indices and the pre-pool size of every
        # stage so the decoder can undo the pooling exactly.
        pooled = []
        for encoder, pool in zip(self.encoderblocks, self.downsamplers):
            x = encoder(x)
            size = list(x.shape[-2:])
            x, indices = pool(x)
            pooled.append((indices, size))
        # Decoder: unpool in reverse order. Passing output_size restores odd
        # heights/widths, which the default (2 * input size) cannot reproduce.
        for unpool, decoder in zip(self.upsamplers, self.decoderblocks):
            indices, size = pooled.pop()
            x = unpool(x, indices=indices, output_size=size)
            x = decoder(x)
        return x
class FCNBase(torch.nn.Module):
    """FCN-32s / FCN-16s / FCN-8s head on top of a five-stage encoder.

    Args:
        in_channels: Accepted for interface parity; not used by the visible code.
        classes: Number of score channels.
        convblocks: torch.nn.ModuleList whose first five entries are the
            encoder stages. The stage comments in the code expect outputs at
            1/2, 1/4, 1/8, 1/16 and 1/32 of the input resolution.
        upsamplevar: 32, 16 or 8 - the stride of the coarsest score map that is
            upsampled to produce the output.

    Raises:
        AssertionError: if `convblocks` is not a ModuleList or `upsamplevar`
            is not one of 8, 16, 32.
    """
    def __init__(self,in_channels,classes,convblocks,upsamplevar=8):
        super().__init__()
        # Raised explicitly (rather than via `assert`) so the checks survive
        # `python -O` and carry a message; the exception type is unchanged.
        if not isinstance(convblocks, torch.nn.ModuleList):
            raise AssertionError("convblocks must be a torch.nn.ModuleList holding five encoder stages")
        if upsamplevar not in (8, 16, 32):
            raise AssertionError("upsamplevar must be 8, 16 or 32")
        self.fcnvar=upsamplevar
        self.encoderblocks = torch.nn.ModuleList([convblocks[stage] for stage in range(5)])
        #The original code uses ConvTranspose2d but with frozen params
        #Therefore, upsampling is replicated with bilinear upsampling
        self.upsampler4 = torch.nn.Upsample(scale_factor=2, mode='bilinear')
        self.upsampler5 = torch.nn.Upsample(scale_factor=2, mode='bilinear')
        # Lazy 1x1 score layers: input width is inferred from the encoder output.
        self.classifier3 = torch.nn.LazyConv2d(classes,kernel_size=(1,1))
        self.classifier4 = torch.nn.LazyConv2d(classes,kernel_size=(1,1))
        self.classifier5 = torch.nn.LazyConv2d(classes,kernel_size=(1,1))
        self.finalupsampler = torch.nn.Upsample(scale_factor=upsamplevar, mode='bilinear')
    def _encode(self, x0):
        """Run the five encoder stages and return the last three outputs."""
        x1 = self.encoderblocks[0](x0) #x1 = /2
        x2 = self.encoderblocks[1](x1) #x2 = /4
        x3 = self.encoderblocks[2](x2) #x3 = /8
        x4 = self.encoderblocks[3](x3) #x4 = /16
        x5 = self.encoderblocks[4](x4) #x5 = /32
        return x3, x4, x5
    def forward(self, x0):
        x3, x4, x5 = self._encode(x0)
        p5 = self.classifier5(x5)
        if self.fcnvar == 32:
            return self.finalupsampler(p5)
        p4 = self.classifier4(x4*0.01) #0.01 = Scaling as intended in the paper
        p5 = self.upsampler5(p5)
        p4 = torch.add(p4,p5)
        if self.fcnvar == 16:
            return self.finalupsampler(p4)
        p3 = self.classifier3(x3*0.0001) #0.0001 = Scaling as intended in the paper
        p4 = self.upsampler4(p4)
        p3 = torch.add(p3,p4)
        return self.finalupsampler(p3)
    def visualizeout(self, x0):
        """Return the raw (un-fused, un-upsampled) score maps as (p5, p4, p3)."""
        x3, x4, x5 = self._encode(x0)
        p5 = self.classifier5(x5)
        p4 = self.classifier4(x4*0.01)
        p3 = self.classifier3(x3*0.0001)
        return (p5,p4,p3)
def INF(B,H,W,device=None):
    """Return a (B*W, H, H) mask with -inf on the diagonal and zeros elsewhere.

    Args:
        B, H, W: Batch size, height and width of the attended feature map.
        device: Where to create the mask. None keeps the previous behaviour on
            GPU machines (current CUDA device) and falls back to the CPU when
            CUDA is unavailable.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return -torch.diag(torch.full((H,), float("inf"), device=device),0).unsqueeze(0).repeat(B*W,1,1)
class CrissCrossAttention(torch.nn.Module):
    """Criss-cross attention: every position attends to the positions in its
    own column and its own row. `gamma` is initialised to zero, so the module
    starts out as the identity.

    Args:
        in_dim: Channels of the input; queries and keys use in_dim//8 channels.
    """
    def __init__(self, in_dim):
        super().__init__()

        self.query_conv = torch.nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = torch.nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = torch.nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = torch.nn.Softmax(dim=3)
        self.INF = INF
        self.gamma = torch.nn.Parameter(torch.zeros(1))
    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        # "_H" tensors fold the width into the batch axis (one sequence per
        # column), "_W" tensors fold the height (one sequence per row).
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        # The -inf diagonal removes each position from its own column scores;
        # it is still covered once by the row scores. The mask is created on
        # x's device so the module also runs on CPU.
        column_mask = self.INF(m_batchsize, height, width, device=x.device)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+column_mask).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        # One softmax over the concatenated column (first `height` entries) and
        # row (remaining `width` entries) scores.
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        # Aggregate the values and restore the (batch, channels, height, width) layout.
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        return self.gamma*(out_H + out_W) + x
class RCCAModule(torch.nn.Module):
    """Recurrent criss-cross attention head.

    A 3x3 conv reduces the input to in_channels//4, criss-cross attention is
    applied `recurrence` times with shared weights, a second 3x3 conv follows,
    and the result is concatenated with the original input and mapped to
    `num_classes` scores.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Width of the fused feature map before the classifier.
        num_classes: Number of output score channels.
    """
    def __init__(self, in_channels, out_channels, num_classes):
        super().__init__()
        inter_channels = in_channels // 4
        self.conva = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            torch.nn.BatchNorm2d(inter_channels),
            torch.nn.ReLU(),
        )
        self.cca = CrissCrossAttention(inter_channels)
        self.convb = torch.nn.Sequential(
            torch.nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            torch.nn.BatchNorm2d(inter_channels),
            torch.nn.ReLU(),
        )
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels+inter_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout2d(0.1),
            # Input width follows out_channels (previously fixed at 512, which
            # only matched when out_channels == 512).
            torch.nn.Conv2d(out_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        )
    def forward(self, x, recurrence=1):
        output = self.conva(x)
        for _ in range(recurrence):
            output = self.cca(output)
        output = self.convb(output)
        output = self.bottleneck(torch.cat([x, output], 1))
        return output
class DilatedResNet(torch.nn.Module):
    """Bottleneck ResNet whose last two stages use dilation (2 and 4) instead
    of striding, followed by a 1x1 score conv and 8x bilinear upsampling.

    `layers` gives the number of BottleResBlock units per stage.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3]):
        super().__init__()
        def make_stage(depth,channels,bottlechan,filters,stride=1,dilation=1,addpool=False):
            units = []
            if addpool:
                units.append(torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            for index in range(depth):
                # Only the first unit of a stage may be strided; every unit
                # of the stage shares the same dilation.
                unit_stride = stride if index == 0 else 1
                units.append(BottleResBlock(channels,bottlechan,filters,unit_stride,dilation=dilation))
                channels = filters
            return torch.nn.Sequential(*units)
        stem = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,64,kernel_size=(7,7),stride=2,padding=3,bias=False),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        # (input width, bottleneck width, output width, keyword options)
        stage_specs = [
            (64,64,256,dict(addpool=True)),
            (256,128,512,dict(stride=2)),
            (512,256,1024,dict(stride=1,dilation=2)),
            (1024,512,2048,dict(stride=1,dilation=4)),
        ]
        stages = [
            make_stage(layers[position],c_in,c_mid,c_out,**options)
            for position,(c_in,c_mid,c_out,options) in enumerate(stage_specs)
        ]
        self.blocks = torch.nn.Sequential(stem, *stages)
        self.head = torch.nn.Conv2d(2048,classes,kernel_size=(1,1),padding='same')
        self.upsampler = torch.nn.Upsample(scale_factor=8, mode='bilinear')
    def forward(self, x):
        features = self.blocks(x)
        scores = self.head(features)
        return self.upsampler(scores)
class CCNet(DilatedResNet):
    """DilatedResNet backbone with a recurrent criss-cross attention head.

    Args:
        in_channels, classes, layers: forwarded to DilatedResNet.
        recurrence: How many times the attention module is applied per
            forward pass.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3],recurrence=2):
        # Forward the caller's `layers` (a literal [3,4,6,3] used to be passed
        # here, silently ignoring the argument).
        super().__init__(in_channels,classes,layers=layers)
        self.head = RCCAModule(2048, 512, classes)
        self.upsampler = torch.nn.Upsample(scale_factor=8, mode='bilinear')
        self.recur = recurrence
    def forward(self, x):
        x = self.blocks(x)
        y = self.head(x,self.recur)
        y = self.upsampler(y)
        return y
class DilatedResNetWithPPM(DilatedResNet):
    """DilatedResNet backbone with a pyramid pooling module before the score conv.

    Args:
        in_channels, classes, layers: forwarded to DilatedResNet.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3]):
        # Forward the caller's `layers` (a literal [3,4,6,3] used to be passed
        # here, silently ignoring the argument).
        super().__init__(in_channels,classes,layers=layers)
        self.ppm = PPM(2048)
        # PPM with its default four pools adds 4 * (2048//4) channels to the
        # 2048 it receives.
        self.head = torch.nn.Conv2d((2048//4)*4+2048,classes,kernel_size=(1,1),padding='same')
    def forward(self, x):
        x = self.blocks(x)
        x = self.ppm(x)
        y = self.head(x)
        y = self.upsampler(y)
        return y

# Loss

In [None]:
class CrossentropyND(torch.nn.CrossEntropyLoss):
    """
    Cross-entropy for inputs with any number of trailing spatial axes.

    The class axis (dim 1) is moved to the end and everything else is
    flattened, so the parent loss sees a (N, num_classes) matrix and a (N,)
    label vector. The network output must be raw scores (no nonlinearity).
    """
    def forward(self, inp, target):
        num_classes = inp.size()[1]

        # (b, c, *spatial) -> (b, *spatial, c), then one row per element
        spatial_axes = range(2, inp.dim())
        flat_scores = inp.permute(0, *spatial_axes, 1).contiguous().view(-1, num_classes)

        flat_labels = target.long().contiguous().view(-1,)

        return super().forward(flat_scores, flat_labels)

class TopKLoss(CrossentropyND):
    """
    Mean cross-entropy over the k percent of elements with the highest loss.
    Network has to have NO NONLINEARITY!

    Args:
        weight: Optional per-class weights, forwarded to CrossEntropyLoss.
        ignore_index: Label value excluded from the loss, forwarded as well.
        k: Percentage (0-100) of elements kept.
    """
    def __init__(self, weight=None, ignore_index=-100, k=10):
        # reduction='none' is the current spelling of the deprecated
        # size_average=False / reduce=False pair: keep per-element losses.
        super().__init__(weight=weight, ignore_index=ignore_index, reduction='none')
        self.k = k

    def forward(self, inp, target):
        res = super().forward(inp, target)
        num_voxels = res.numel()
        # Keep at least one element (a tiny k would otherwise select nothing
        # and yield NaN) and never more than exist.
        keep = min(num_voxels, max(1, int(num_voxels * self.k / 100)))
        res, _ = torch.topk(res.view((-1, )), keep, sorted=False)
        return res.mean()

def sum_tensor(inp, axes, keepdim=False):
    """Sum `inp` over every axis listed in `axes` (duplicates are ignored).

    With keepdim the reduced axes stay as size-1 axes; without it they are
    removed, reducing from the highest axis down so the remaining axis
    numbers stay valid.
    """
    # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/tensor_utilities.py
    dims = [int(ax) for ax in np.unique(axes).astype(int)]
    keep = bool(keepdim)
    if not keep:
        dims.reverse()
    for dim in dims:
        inp = inp.sum(dim, keepdim=keep)
    return inp
    
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    Soft true-positive / false-positive / false-negative sums of a prediction
    against a label map.

    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output: prediction scores, one channel per class
    :param gt: ground truth, label map or one-hot (see above)
    :param axes: axes to sum over; defaults to every axis from index 2 on
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: (tp, fp, fn), each summed over `axes`
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # Insert the missing channel axis; reshape also accepts non-contiguous gt.
            gt = gt.reshape((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            # Allocate directly on net_output's device so any backend works, not only CUDA.
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        # Apply the single-channel mask to every class channel.
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn
    
class SoftDiceLoss(torch.nn.Module):
    """Negative soft Dice coefficient, averaged over the retained entries.

    paper: https://arxiv.org/pdf/1606.04797.pdf
    """
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        super().__init__()

        self.apply_nonlin = apply_nonlin
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.smooth = smooth
        self.square = square
    def forward(self, x, y, loss_mask=None):
        # Reduce over every axis after the channel axis; with batch_dice the
        # batch axis is pooled as well.
        trailing_axes = list(range(2, len(x.shape)))
        axes = [0] + trailing_axes if self.batch_dice else trailing_axes

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)

        numerator = 2 * tp + self.smooth
        denominator = 2 * tp + fp + fn + self.smooth
        dc = numerator / denominator

        if not self.do_bg:
            # Drop channel 0: it is the first axis when the batch was pooled,
            # the second axis otherwise.
            dc = dc[1:] if self.batch_dice else dc[:, 1:]

        return -dc.mean()
        
class DC_and_CE_loss(torch.nn.Module):
    """Cross-entropy plus the (negative) soft Dice term, evaluated on the same
    network output and target. Only aggregate="sum" is implemented."""
    def __init__(self, apply_nonlin=None, aggregate="sum"):
        super().__init__()
        self.aggregate = aggregate
        self.ce = CrossentropyND()
        self.dc = SoftDiceLoss(apply_nonlin)
    def forward(self, net_output, target):
        dice_term = self.dc(net_output, target)
        ce_term = self.ce(net_output, target)
        if self.aggregate != "sum":
            raise NotImplementedError("nah son") # reserved for other stuff (later)
        return ce_term + dice_term

class TverskyLoss(torch.nn.Module):
    """Negative Tversky index; alpha weights false positives and beta weights
    false negatives.

    paper: https://arxiv.org/pdf/1706.05721.pdf
    """
    def __init__(self, apply_nonlin=None, alpha=0.3, beta=0.7, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        super().__init__()

        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.beta = beta
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.smooth = smooth
        self.square = square

    def forward(self, x, y, loss_mask=None):
        # Reduce over every axis after the channel axis; with batch_dice the
        # batch axis is pooled as well.
        trailing_axes = list(range(2, len(x.shape)))
        axes = [0] + trailing_axes if self.batch_dice else trailing_axes

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)

        numerator = tp + self.smooth
        denominator = tp + self.alpha*fp + self.beta*fn + self.smooth
        tversky = numerator / denominator

        if not self.do_bg:
            # Drop channel 0: first axis when the batch was pooled, second otherwise.
            tversky = tversky[1:] if self.batch_dice else tversky[:, 1:]

        return -tversky.mean()

class FocalTversky_loss(torch.nn.Module):
    """
    Focal Tversky loss: (1 - tversky index) raised to the power gamma.

    paper: https://arxiv.org/pdf/1810.07842.pdf
    author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65
    """
    def __init__(self, tversky_kwargs, gamma=0.75):
        super().__init__()
        self.gamma = gamma
        self.tversky = TverskyLoss(**tversky_kwargs)

    def forward(self, net_output, target):
        # TverskyLoss returns the negated index, so adding 1 yields 1 - index.
        negated_index = self.tversky(net_output, target)
        return torch.pow(1 + negated_index, self.gamma)

# Trainer

In [None]:
def ohe(y,classes):
    """One-hot encode an integer label map and move the class axis to position 1.

    Works for any number of trailing axes: (B, H, W) -> (B, classes, H, W),
    (B, D, H, W) -> (B, classes, D, H, W). Returns a float tensor.
    """
    # one_hot appends the class axis last; movedim places it right after the batch axis.
    return torch.nn.functional.one_hot(y.long(),classes).movedim(-1,1).float()
def iou(pred, target, n_classes = 2, include_bg_class = True):
    """Per-sample intersection-over-union for each class.

    Args:
        pred, target: integer label maps of shape BATCH x H x W.
        n_classes: number of classes; labels run from 0 to n_classes-1.
        include_bg_class: when False, class 0 is skipped.

    Returns:
        A list with one numpy array per evaluated class, each of length BATCH.
        An entry is NaN when the class is absent from both pred and target
        (0/0); callers are expected to mask those out.
    """
    pred = pred.cpu()
    target = target.cpu()
    first_cls = 0 if include_bg_class else 1
    ious = []
    for cls in range(first_cls, n_classes):
        pred2 = pred == cls
        target2 = target == cls
        intersection = torch.logical_and(pred2,target2).count_nonzero((1,2)).numpy()
        union = pred2.count_nonzero((1,2)).numpy() + target2.count_nonzero((1,2)).numpy() - intersection
        # 0/0 is an expected outcome here (class absent); produce NaN quietly
        # instead of emitting a RuntimeWarning for every such sample.
        with np.errstate(divide='ignore', invalid='ignore'):
            ious.append(intersection/union)
    return ious
def fone(pred, target, n_classes = 2, include_bg_class = True): #a.k.a. Dice
    """Per-sample F1 (Dice) score for each class.

    Args:
        pred, target: integer label maps of shape BATCH x H x W.
        n_classes: number of classes; labels run from 0 to n_classes-1.
        include_bg_class: when False, class 0 is skipped.

    Returns:
        A list with one numpy array per evaluated class, each of length BATCH.
        An entry is NaN when the class is absent from both pred and target
        (0/0); callers are expected to mask those out.
    """
    pred = pred.cpu()
    target = target.cpu()
    first_cls = 0 if include_bg_class else 1
    scores = []
    for cls in range(first_cls, n_classes):
        pred2 = pred == cls
        target2 = target == cls
        intersectcount = torch.logical_and(pred2,target2).count_nonzero((1,2)).numpy()
        targetcount = target2.count_nonzero((1,2)).numpy()
        predcount = pred2.count_nonzero((1,2)).numpy()
        # 0/0 is an expected outcome here (class absent); produce NaN quietly
        # instead of emitting a RuntimeWarning for every such sample.
        with np.errstate(divide='ignore', invalid='ignore'):
            scores.append(2*intersectcount/(targetcount+predcount))
    return scores
class EarlyStopper:
    """Tracks the best value of one history metric and signals when to stop.

    early_stop() returns -1 on a new best, 1 once `patience` consecutive
    non-improving checks have accumulated after the grace period, 0 otherwise.
    """
    def __init__(self, metric_name='val_loss', lower_is_better=True, patience=4, grace_period=0):
        self.metric_name = metric_name
        self.patience = patience
        self.grace_period = grace_period
        self.grace_count = 0
        self.counter = 0
        self.lower_is_better = lower_is_better
        self.best_metric = float('inf') if self.lower_is_better else -float('inf')
    def updatebest(self,metric):
        """Store `metric` as the new best if it beats the current best; return whether it did."""
        if self.lower_is_better:
            improved = metric < self.best_metric
        else:
            improved = metric > self.best_metric
        if improved:
            self.best_metric = metric
            return True
        return False
    def early_stop(self, history, lower_is_better=True):
        """Check the latest entry of history[metric_name].

        The `lower_is_better` argument is not used; the direction chosen at
        construction time applies.
        """
        latest = history[self.metric_name][-1]
        in_grace = self.grace_count < self.grace_period
        if in_grace:
            self.grace_count += 1
        if self.updatebest(latest):
            if not in_grace:
                self.counter = 0
            return -1
        if in_grace:
            # Non-improving checks during the grace period do not count towards patience.
            return 0
        self.counter += 1
        return 1 if self.counter >= self.patience else 0
    def reset(self):
        """Clear the patience counter and the best metric (the grace counter is left as is)."""
        self.counter = 0
        self.best_metric = float('inf') if self.lower_is_better else -float('inf')

In [None]:
class Manager():
    """Drives training, validation, evaluation and prediction for a segmentation model.

    The model may return a plain tensor or an OrderedDict with an "out" entry and,
    optionally, an "aux" entry whose loss is weighted by multi_losses["aux"].
    IoU / F1 are computed for class 1 of a 2-class problem (class 0 is skipped).
    """
    _HISTORY_SERIES = ('train_loss', 'train_miou', 'train_f1', 'val_loss', 'val_miou', 'val_f1')

    def __init__(self, train_dl, val_dl, model, loss_fn, optimizer, preprocesser = None, scheduler = None, multi_losses = None):
        self.epoch = 0
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.preprocesser = preprocesser
        self.scheduler = scheduler
        self.multi_losses = multi_losses
        # Device and dtype are taken from the model's first parameter.
        parametertensor = next(self.model.parameters())
        self.to_device = parametertensor.device
        self.rescaler = torchvision.transforms.v2.ToDtype(parametertensor.dtype,True) #infer the dtype of input
        self.stopnow = False
        self.personalbest = False
        self.pbrecord = None  # (model state_dict, optimizer state_dict) of the best epoch, once saved
        self.history = {'train_loss': [], 'train_miou': [], 'train_f1': [], 'val_loss': [], 'val_miou': [], 'val_f1': [], 'best_epoch': 0}
    def _forward_loss(self, X, y):
        """Run the model on X and return (loss tensor, tensor used for predictions)."""
        out = self.model(X)
        if type(out) is OrderedDict:
            if "out" not in out:
                raise AssertionError("Model doesn't produce output!")
            pred = out["out"]
            loss = self.loss_fn(pred, y)
            if "aux" in out:
                #Check whether aux losses are defined
                if type(self.multi_losses) is OrderedDict and "aux" in self.multi_losses:
                    #Do aux losses
                    loss = loss + self.loss_fn(out["aux"], y)*self.multi_losses["aux"]
                else:
                    raise AssertionError("Unnecessary aux output! Disable it!")
        else:
            #Good ol' tensor output
            loss = self.loss_fn(out, y)
            pred = out
        return loss, pred
    @staticmethod
    def _collect_metrics(pred, y, ious, fones):
        """Append this batch's per-sample IoU and F1 (class 1 only) to the running lists."""
        pred = pred.argmax(1)
        for i in iou(pred,y,2,False):
            ious.extend(i.tolist())
        for i in fone(pred,y,2,False):
            fones.extend(i.tolist())
    @staticmethod
    def _summarize(losses, batchsizes, ious, fones):
        """Batch-size weighted mean loss plus NaN-ignoring mean IoU and F1."""
        meanloss = np.average(np.array(losses),weights=np.array(batchsizes))
        meaniou = np.ma.mean(np.ma.masked_invalid(ious))
        meanfone = np.ma.mean(np.ma.masked_invalid(fones))
        return meanloss, meaniou, meanfone
    def _run_eval(self, dataloader):
        """Evaluate the model on `dataloader` without gradients; return (loss, mIoU, F1)."""
        self.model.eval()
        batchsizes, losses, ious, fones = [], [], [], []
        with torch.no_grad(): # reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
            for X, y in dataloader:
                X = self.rescaler(X).to(device=self.to_device,memory_format=torch.channels_last)
                y = y.to(device=self.to_device,dtype=torch.int64)
                loss, pred = self._forward_loss(X, y)
                self._collect_metrics(pred, y, ious, fones)
                batchsizes.append(len(X))
                losses.append(loss.item()) #Itemized out of pytorch's tensor
        return self._summarize(losses, batchsizes, ious, fones)
    def train_one_epoch(self):
        """Run one optimization pass over train_dl and append the epoch means to history."""
        size = len(self.train_dl.dataset)
        count = 0
        batchsizes, losses, ious, fones = [], [], [], []
        self.model.train()
        for X, y in self.train_dl:
            #preprocess layers here
            X = self.rescaler(X).to(device=self.to_device,memory_format=torch.channels_last)
            if self.preprocesser != None:
                X, y = self.preprocesser(X, torchvision.tv_tensors.Mask(y))
            y = y.to(device=self.to_device,dtype=torch.int64)
            # Compute prediction and loss
            loss, pred = self._forward_loss(X, y)
            # Backpropagation
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Stop propagating gradient
            loss = loss.item()
            # and relevant metrics
            self._collect_metrics(pred, y, ious, fones)
            batchsizes.append(len(X))
            losses.append(loss)
            count = count + len(X)
            print(f"loss: {loss:>7f}  [{count:>5d}/{size:>5d}]",end = '\r')
        print()
        meanloss, meaniou, meanfone = self._summarize(losses, batchsizes, ious, fones)
        print(f"Training Error:   mIoU: {(100*meaniou):>.02f}%, F1: {(100*meanfone):>.02f}%, Avg loss: {meanloss:>8f}")
        self.history['train_loss'].append(meanloss)
        self.history['train_miou'].append(meaniou)
        self.history['train_f1'].append(meanfone)
    def validate(self, earlystop):
        """Evaluate on val_dl, record the results and query the early stopper.

        Sets self.stopnow (stopper returned 1) and self.personalbest (returned -1).
        """
        meanloss, meaniou, meanfone = self._run_eval(self.val_dl)
        print(f"Validating Error: mIoU: {(100*meaniou):>.02f}%, F1: {(100*meanfone):>.02f}%, Avg loss: {meanloss:>8f}")
        self.history['val_loss'].append(meanloss)
        self.history['val_miou'].append(meaniou)
        self.history['val_f1'].append(meanfone)
        earlystopout = earlystop.early_stop(self.history)
        self.stopnow = earlystopout == 1
        self.personalbest = earlystopout == -1
    def evaluate(self, dataloader):
        """Evaluate on an arbitrary dataloader; returns {'loss', 'miou', 'f1'} without touching history."""
        meanloss, meaniou, meanfone = self._run_eval(dataloader)
        return {'loss': meanloss, 'miou': meaniou, 'f1': meanfone}
    def predict(self, X):
        """Run the model on a single unbatched input and return its output with the batch axis removed."""
        self.model.eval()
        with torch.no_grad():
            X = self.rescaler(X.unsqueeze(0)).to(device=self.to_device,memory_format=torch.channels_last)
            out = self.model(X)
            if type(out) is OrderedDict:
                if "out" not in out:
                    raise AssertionError("Model doesn't produce output!")
                pred = out["out"]
            else:
                #Good ol' tensor output
                pred = out
        return pred.squeeze(0)
    def train(self, epochs, earlystop, restore_best_weights = True):
        """Train until `epochs` is reached (None = unbounded) or the early stopper fires.

        With a ReduceLROnPlateau scheduler, every learning-rate drop reverts the
        model (and history) to the best epoch recorded so far. Returns self.history.
        """
        while epochs == None or self.epoch < epochs:
            if self.scheduler != None:
                print(f"Epoch {self.epoch+1:<6d} ({self.scheduler.get_last_lr()[0]:<5e})------------------------")
            else:
                print(f"Epoch {self.epoch+1:<6d}-------------------------------")
            self.train_one_epoch()
            self.validate(earlystop)
            if self.personalbest:
                if restore_best_weights:
                    print('New personal best! Saving best weights...')
                    self.pbrecord = (copy.deepcopy(self.model.state_dict()),copy.deepcopy(self.optimizer.state_dict()))
                else:
                    print('New personal best!')
                self.history['best_epoch'] = self.epoch #zero-indexed
            if self.scheduler != None:
                if type(self.scheduler) is torch.optim.lr_scheduler.ReduceLROnPlateau:
                    lastlr = self.scheduler.get_last_lr()[0]
                    # Always feed the value validate() just appended.
                    self.scheduler.step(self.history['val_loss'][-1])
                    if lastlr != self.scheduler.get_last_lr()[0]:
                        self.restorebestweights(True)
                        earlystop.counter = 0
                else:
                    self.scheduler.step()
            print()
            if self.stopnow:
                print('Stop.')
                print()
                break
            print()
            self.epoch += 1
        print('Training ended.')
        if restore_best_weights:
            self.restorebestweights(False)
        return self.history
    def restorebestweights(self, revert = False):
        """Load the saved best model weights.

        With revert=True the history series and the epoch counter are also rolled
        back to the best epoch. Does nothing when no best weights were recorded
        (e.g. training ran with restore_best_weights=False, or never improved).
        """
        if self.pbrecord is None:
            print('No best weights recorded, nothing to restore.')
            return
        if revert:
            print('Reverting to best epoch...')
            self.model.load_state_dict(self.pbrecord[0])
            # The optimizer state in pbrecord[1] is intentionally not restored.
            cutoff = self.history['best_epoch']+1
            for key in self._HISTORY_SERIES:
                del self.history[key][cutoff:]
            self.epoch = self.history['best_epoch']
        else:
            print('Restoring best weights...')
            self.model.load_state_dict(self.pbrecord[0])

# Configurations

In [None]:
# NOTE: machine-specific absolute Windows path; adjust before running elsewhere.
dataset_dir = r'C:\Users\User\Documents\AAAAA\eseg2\Fish EUS.v3i.coco-segmentation-fix'
DATASETNAME = 'FishV1'

# One sub-folder per split, joined with a backslash.
train_img_dir, val_img_dir, test_img_dir = (
    dataset_dir + '\\' + split for split in ('train', 'valid', 'test')
)

In [None]:
def init_lazies(model):
    """Materialize lazy modules by pushing one random 3-channel HEIGHT x WIDTH input
    through the model in eval mode, then print the output shape.

    The model is left in eval mode afterwards.
    """
    model.eval()
    with torch.no_grad():
        probe = torch.rand(1, 3, HEIGHT, WIDTH).to(device=DEVICE,dtype=DTYPE)
        result = model(probe)
    print(result.shape)
def init_weights(m):
    """He-initialize the weights (and zero the bias) of one Linear / Conv2d /
    ConvTranspose2d module; intended for use with `model.apply(init_weights)`.

    Modules of any other exact type are only printed. Frozen weights
    (requires_grad False) are left untouched.

    Raises:
        AssertionError: if a lazy module has not been materialized yet.
    """
    print(type(m))
    # Exact type match on purpose: subclasses are not initialized.
    initializable = (
        torch.nn.Linear, torch.nn.LazyLinear,
        torch.nn.Conv2d, torch.nn.LazyConv2d,
        torch.nn.ConvTranspose2d, torch.nn.LazyConvTranspose2d,
    )
    if type(m) not in initializable:
        return
    if type(m.weight) is torch.nn.UninitializedParameter:
        print("Skipping uninitialized weights....")
        # Raised explicitly so the check survives `python -O` and carries a message.
        raise AssertionError("lazy module has uninitialized weights; run a forward pass before init_weights")
    print("Shape: "+str(m.weight.shape)+" + "+str(m.bias.shape if m.bias is not None else ""))
    if not m.weight.requires_grad:
        print("Weights frozen, skipping....")
    elif len(m.weight.shape) > 1:
        torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
        print("Applied He initializer!")
    else:
        print("He initializer not applicable for 1 dimensional tensor.")
def init_eye(m,x=1):
    """Overwrite m.weight with zeros except weight[i, i] = x, and zero the bias.

    For weights with more than two axes the whole trailing slice at [i, i] is
    set to x. Non-square weights are supported: the diagonal stops at the
    smaller of the first two axes.
    """
    w = torch.zeros(m.weight.shape, dtype=m.weight.dtype)
    # Index lists must have equal length, so stop at the shorter of the two axes.
    diag = list(range(min(m.weight.shape[0], m.weight.shape[1])))
    w[diag, diag] = x
    m.weight.data.copy_(w)
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)
    print("Applied... I... initializer!")
class AddGaussianNoise(object):
    """Paired transform: adds N(mean, std^2) noise to the first tensor and passes
    the second argument through unchanged."""
    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std
    def __call__(self, tensor, tensor2):
        scaled_noise = torch.randn(tensor.size(),device=tensor.device) * self.std
        noisy = tensor + scaled_noise + self.mean
        return noisy, tensor2
    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
class Mosaic(object):
    """Batch-level mosaic augmentation for image/label pairs.

    For every output sample a random split point is drawn and the four resulting
    quadrants are filled with crops from four distinct, randomly chosen samples
    of the same batch. Labels are tiled identically and returned in the image dtype.

    Args:
        minfrac, maxfrac: the split point is drawn from
            [round(size*minfrac), round(size*maxfrac)) along each axis.
    """
    def __init__(self, minfrac=0.25, maxfrac=0.75):
        self.minfrac = minfrac
        self.maxfrac = maxfrac
    def __call__(self, images, labels):
        if images.shape[0] < 4:
            raise AssertionError("Mosaic needs a batch of at least 4 images")
        # Tensors are sliced as [..., rows, columns]: y spans rows (dim 2) and
        # x spans columns (dim 3), so non-square images are tiled correctly.
        y, x = images.shape[2], images.shape[3]
        imagetile = torch.zeros(images.shape,dtype=images.dtype,device=images.device)
        labeltile = torch.zeros(labels.shape,dtype=images.dtype,device=images.device)
        for i in range(images.shape[0]):
            xc = torch.randint(round(x * self.minfrac), round(x * self.maxfrac), ()).item()
            yc = torch.randint(round(y * self.minfrac), round(y * self.maxfrac), ()).item()
            indices = torch.randperm(images.shape[0]).tolist()
            # (destination x1,y1,x2,y2), (source u1,v1,u2,v2): each quadrant takes
            # the source corner that touches the split point.
            quadrants = (
                ((0, 0, xc, yc), (x-xc, y-yc, x, y)),   #Top left
                ((xc, 0, x, yc), (0, y-yc, x-xc, y)),   #Top right
                ((0, yc, xc, y), (x-xc, 0, x, y-yc)),   #Bottom left
                ((xc, yc, x, y), (0, 0, x-xc, y-yc)),   #Bottom right
            )
            for j, ((x1,y1,x2,y2), (u1,v1,u2,v2)) in enumerate(quadrants):
                imagetile[i][0:images.shape[1],y1:y2,x1:x2] = images[indices[j]][0:images.shape[1],v1:v2,u1:u2]
                labeltile[i][y1:y2,x1:x2] = labels[indices[j]][v1:v2,u1:u2]
        return imagetile, labeltile
    def __repr__(self):
        return self.__class__.__name__ + '(minfrac={0}, maxfrac={1})'.format(self.minfrac, self.maxfrac)
class CutMixMask(object):
    """Batch-level CutMix for image/mask pairs.

    For every sample a random rectangle is replaced, in both image and label,
    by the same rectangle of another sample from the batch. Labels are returned
    in the image dtype.

    Args:
        alpha: both parameters of the Beta distribution from which the mixing
            ratio (and thus the rectangle size) is drawn.
    """
    def __init__(self, alpha=1.0):
        self.alpha = alpha
        self.dist = torch.distributions.Beta(torch.tensor([self.alpha]), torch.tensor([self.alpha]))
    def __call__(self, images, labels):
        if images.shape[0] < 2:
            raise AssertionError("CutMixMask needs a batch of at least 2 images")
        # Tensors are sliced as [..., rows, columns]: y spans rows (dim 2) and
        # x spans columns (dim 3), so the patch can reach the full extent of
        # non-square images.
        y, x = images.shape[2], images.shape[3]
        imagetile = torch.zeros(images.shape,dtype=images.dtype,device=images.device)
        labeltile = torch.zeros(labels.shape,dtype=images.dtype,device=images.device)
        for i in range(images.shape[0]):
            #Generate
            lam = float(self.dist.sample(()))
            xc, yc = torch.randint(0, x, ()).item(), torch.randint(0, y, ()).item()
            r = 0.5 * math.sqrt(1.0 - lam)
            xs, ys = round(r * x), round(r * y)
            x1, y1 = max(0,xc-xs), max(0,yc-ys)
            x2, y2 = min(x,xc+xs), min(y,yc+ys)
            #adjusted lam not needed, label will be a mask
            # Pick a partner uniformly among the other samples: draw from B-1
            # slots and skip over i.
            index = torch.randint(0, images.shape[0]-1, ()).item()
            index += 1 if index >= i else 0
            #Cut and Mix!
            imagetile[i][0:images.shape[1]] = images[i][0:images.shape[1]]
            imagetile[i][0:images.shape[1],y1:y2,x1:x2] = images[index][0:images.shape[1],y1:y2,x1:x2]
            labeltile[i] = labels[i]
            labeltile[i][y1:y2,x1:x2] = labels[index][y1:y2,x1:x2]
        return imagetile, labeltile
    def __repr__(self):
        return self.__class__.__name__ + '(alpha={0})'.format(self.alpha)

In [None]:
# Run-wide settings. DTYPE (together with DEVICE) is what every model below is
# moved to, and what init_lazies uses for its dummy input.
DTYPE = torch.float32
BATCH_SIZE = 5

EPOCHS = 1200
# Number of output classes handed to every model constructor below.
CLASSES = 2
# NOTE(review): AUXLOSSES / UNPACKER are not referenced in the visible part of
# this file — presumably overridden per model further down; confirm.
AUXLOSSES = None
UNPACKER = None
scheduler = None

SAVEMODEL = False

# Spatial size of the dummy input used by init_lazies.
#HEIGHT, WIDTH = 224, 224
HEIGHT, WIDTH = 448, 448

# Checkpoints read with torch.load by the *CLS* model variants below.
# NOTE: absolute, machine-specific Windows paths.
VGG16_Path = r'C:\Users\User\Documents\AAAAA\eseg2\VGG16-448^2-Fishv2_CLS_DE_AA+CutMixUp-SGD_1e-4_0,9-CE_wprop-BS8-E87-68,08%_34,04%_45,45%.pth'
PRE_VGG16_Path = r'C:\Users\User\Documents\AAAAA\eseg2\PRE_VGG16-448^2-Fishv2_CLS_DE_AA+CutMixUp-SGD_1e-4_0,9-CE_wprop-BS8-E73-88,26%_61,70%_61,36%.pth'
ResNet50_Path = r'C:\Users\User\Documents\AAAAA\eseg2\ResNet50-448^2-Fishv2_CLS_DE_AA+CutMixUp-SGD_1e-3_0,9_1e-4-CE_wprop-BS8-E40-39,91%_34,04%_40,91%.pth'
PRE_ResNet50_Path = r'C:\Users\User\Documents\AAAAA\eseg2\PRE_ResNet50-448^2-Fishv2_CLS_DE_AA+CutMixUp-SGD_1e-3_0,9_1e-4-CE_wprop-BS8-E50-98,59%_68,09%_63,64%.pth'

# The bare string below is a no-op expression used as a block comment; it
# lists model names for MODELNAME.
'''Model selection:
UNet
UNet+VGG16
UNet+Pretrained_VGG16
UNet+Pre_CLS_VGG16
UNet+CLS_VGG16
UNet+ResNet50
UNet+Pretrained_ResNet50
UNet+Pre_CLS_ResNet50
UNet+CLS_ResNet50

FCN8+VGG16
FCN8+Pretrained_VGG16
FCN8+Pre_CLS_VGG16
FCN8+Pre_CLS_VGG16+4096
FCN8+CLS_VGG16
FCN8+ResNet50
FCN8+Pretrained_ResNet50
FCN8+Pre_CLS_ResNet50
FCN8+CLS_ResNet50

SegNet16
Pretrained_SegNet16
Pretrained_SegNet16+PTDec
Pre_CLS_SegNet16
CLS_SegNet16

UNet+Alt
UNet+Pretrained_ResNet50+Alt
'''

# Default model; may be replaced by the command-line override parsed below.
MODELNAME = 'UNet+Pre_CLS_VGG16'

# Same no-op string convention: optimizer preset names for OPTIMIZERNAME, with
# the author's notes on how each behaved.
'''Optimizer selection:
SGD_1e_4_0_99 #UNet -> Good
SGD_3e_4_0_99 #
SGD_3e_4_0_99_3e_5 #
SGD_1e_3_0_99 #FCN+VGG -> Bad/Ok? #FCN+ResNet -> Ok
SGD_1e_3_0_99_5e_4 #FCN+VGG -> Good #FCN+ResNet -> Ok
SGD_1e_3_0_99_2e_4 #FCN+VGG -> ??? #FCN+ResNet -> Ok
SGDPoly_1e_2_0_99
SGD_1e_1_0_9 #SegNet -> Good
SGD_3e_2_0_9
SGD_1e_2_0_9 #Slower training
SGDDiff_1e_4_0_99
SGDExp_1e_2_0_99
Adam_1e_4
Adam_3e_5
'''

# Default optimizer preset; may be replaced by the command-line override parsed below.
OPTIMIZERNAME = 'SGD_1e_4_0_99'

# Optional command-line override, passed as ONE argument of the form
#   --Model:<name>:--Optimizer:<name>
# A last argv entry starting with 'C:\' is treated as a path, not an override.
ARGS = ARGNAME.split(":")
print(ARGS)
if not ARGNAME.startswith('C:\\'):
    # Length checks keep a short or unrelated last argument (e.g. only
    # "--Model:UNet", or a path without ':') from raising IndexError.
    if len(ARGS) > 1 and ARGS[0] == '--Model':
        MODELNAME = ARGS[1]
    if len(ARGS) > 3 and ARGS[2] == '--Optimizer':
        OPTIMIZERNAME = ARGS[3]

match MODELNAME:
    #########
    # U-Net #
    #########
    case 'UNet':
        model = UNet(3,CLASSES).to(device=DEVICE,dtype=DTYPE)
        model.apply(init_weights)
    case 'UNet+VGG16':
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights=None)
        backbone.requires_grad_(True) #Unfreeze
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[6:13],backbone.features[13:23],backbone.features[23:33],backbone.features[33:43]]
        )
        model = UNet(3,CLASSES,convblocks).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        del backbone
    case 'UNet+Pretrained_VGG16':
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights='VGG16_BN_Weights.IMAGENET1K_V1')
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[6:13],backbone.features[13:23],backbone.features[23:33],backbone.features[33:43]]
        )
        model = UNet(3,CLASSES,convblocks).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'UNet+Pre_CLS_VGG16':
        backbone = torch.load(PRE_VGG16_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[6:13],backbone.features[13:23],backbone.features[23:33],backbone.features[33:43]]
        )
        model = UNet(3,CLASSES,convblocks,channels=[64,128,256,512,512]).to(device=DEVICE,dtype=DTYPE) #################################################################################
        init_lazies(model) #,channels=[64,128,256,512,512]
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'UNet+CLS_VGG16':
        backbone = torch.load(VGG16_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[6:13],backbone.features[13:23],backbone.features[23:33],backbone.features[33:43]]
        )
        model = UNet(3,CLASSES,convblocks).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'UNet+ResNet50':
        backbone = torch.hub.load('pytorch/vision', 'resnet50', weights=None)
        backbone.requires_grad_(True) #Unfreeze
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = UNet(3,CLASSES,convblocks,upsample_output=2).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        del backbone
    case 'UNet+Pretrained_ResNet50':
        backbone = torch.hub.load('pytorch/vision', 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V2')
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = UNet(3,CLASSES,convblocks,upsample_output=2).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'UNet+Pre_CLS_ResNet50':
        backbone = torch.load(PRE_ResNet50_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = UNet(3,CLASSES,convblocks,upsample_output=2).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'UNet+CLS_ResNet50':
        backbone = torch.load(ResNet50_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = UNet(3,CLASSES,convblocks,upsample_output=2).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    #######
    # FCN #
    #######
    case 'FCN8+VGG16':
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights=None)
        backbone.requires_grad_(True) #Unfreeze
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:7],backbone.features[7:14],backbone.features[14:24],backbone.features[24:34],backbone.features[34:44]]
        )
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        #for m in [model.classifier3,model.classifier4]:
        #    torch.nn.init.zeros_(m.weight) #zero-init scoring layers with a skip
        #    torch.nn.init.zeros_(m.bias)
        del backbone
    case 'FCN8+Pretrained_VGG16':
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights='VGG16_BN_Weights.IMAGENET1K_V1')
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:7],backbone.features[7:14],backbone.features[14:24],backbone.features[24:34],backbone.features[34:44]]
        )
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'FCN8+Pre_CLS_VGG16':
        backbone = torch.load(PRE_VGG16_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:7],backbone.features[7:14],backbone.features[14:24],backbone.features[24:34],backbone.features[34:44]]
        )
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'FCN8+Pre_CLS_VGG16+4096':
        backbone = torch.load(PRE_VGG16_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:7],backbone.features[7:14],backbone.features[14:24],backbone.features[24:34],backbone.features[34:44]]
        )
        convblocks[4].append(torch.nn.Conv2d(512,4096,kernel_size=(3,3),padding=(1,1),bias=False))
        convblocks[4].append(torch.nn.BatchNorm2d(4096))
        convblocks[4].append(torch.nn.ReLU(inplace=True))
        convblocks[4].append(torch.nn.Conv2d(4096,4096,kernel_size=(3,3),padding=(1,1),bias=False))
        convblocks[4].append(torch.nn.BatchNorm2d(4096))
        convblocks[4].append(torch.nn.ReLU(inplace=True))
        convblocks[4].append(torch.nn.Conv2d(4096,CLASSES,kernel_size=(1,1),bias=False))
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'FCN8+CLS_VGG16':
        backbone = torch.load(VGG16_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:7],backbone.features[7:14],backbone.features[14:24],backbone.features[24:34],backbone.features[34:44]]
        )
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'FCN8+ResNet50':
        backbone = torch.hub.load('pytorch/vision', 'resnet50', weights=None)
        backbone.requires_grad_(True) #Unfreeze
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        del backbone
    case 'FCN8+Pretrained_ResNet50':
        backbone = torch.hub.load('pytorch/vision', 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V2')
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'FCN8+Pre_CLS_ResNet50':
        backbone = torch.load(PRE_ResNet50_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'FCN8+CLS_ResNet50':
        backbone = torch.load(ResNet50_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = FCNBase(3,CLASSES,convblocks,upsamplevar=8).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    ##########
    # SegNet #
    ##########
    case 'SegNet16':
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights=None)
        backbone.requires_grad_(True) #Unfreeze
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[7:13],backbone.features[14:23],backbone.features[24:33],backbone.features[34:43],]
        )
        model = SegNet(3,CLASSES,convblocks=convblocks).to(device=DEVICE,dtype=DTYPE)
        model.apply(init_weights)
        del backbone
    case 'Pretrained_SegNet16':
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights='VGG16_BN_Weights.IMAGENET1K_V1')
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[7:13],backbone.features[14:23],backbone.features[24:33],backbone.features[34:43],]
        )
        model = SegNet(3,CLASSES,convblocks=convblocks).to(device=DEVICE,dtype=DTYPE)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'Pretrained_SegNet16+PTDec':
        backbone = torch.hub.load('pytorch/vision', 'vgg16_bn', weights='VGG16_BN_Weights.IMAGENET1K_V1')
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[7:13],backbone.features[14:23],backbone.features[24:33],backbone.features[34:43],]
        )
        model = SegNet(3,CLASSES,convblocks=convblocks,ptdec=True).to(device=DEVICE,dtype=DTYPE)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'Pre_CLS_SegNet16':
        backbone = torch.load(PRE_VGG16_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[7:13],backbone.features[14:23],backbone.features[24:33],backbone.features[34:43],]
        )
        model = SegNet(3,CLASSES,convblocks=convblocks).to(device=DEVICE,dtype=DTYPE)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case 'CLS_SegNet16':
        backbone = torch.load(VGG16_Path)
        backbone.requires_grad_(False) #Freeze the pretrained weights before applying init
        convblocks = torch.nn.ModuleList(
            [backbone.features[0:6],backbone.features[7:13],backbone.features[14:23],backbone.features[24:33],backbone.features[34:43],]
        )
        model = SegNet(3,CLASSES,convblocks=convblocks).to(device=DEVICE,dtype=DTYPE)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    #######
    # ??? #
    #######
    case 'UNet+Alt':
        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=True).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        #model.apply(init_weights)
        #del backbone
    case 'UNet+Pretrained_ResNet50+Alt':
        backbone = torch.hub.load('pytorch/vision:v0.10.0', 'fcn_resnet50', weights='FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1').backbone
        #backbone.conv1.stride = (1,1)
        convblocks = torch.nn.ModuleList([
            torch.nn.Sequential(backbone.conv1,backbone.bn1,backbone.relu),
            torch.nn.Sequential(backbone.maxpool,backbone.layer1),
            backbone.layer2,
            backbone.layer3,
            backbone.layer4
        ])
        model = UNet(3,CLASSES,convblocks,upsample_output=2).to(device=DEVICE,dtype=DTYPE)
        init_lazies(model)
        model.apply(init_weights)
        backbone.requires_grad_(True) #Unfreeze
        del backbone
    case _:
        assert(False)

match OPTIMIZERNAME:
    case 'SGD_1e_4_0_99':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.99, weight_decay=0) #Default
    case 'SGD_3e_4_0_99':
        optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.99, weight_decay=0) #
    case 'SGD_3e_4_0_99_3e_5':
        optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.99, weight_decay=3e-5) #
    case 'SGD_1e_3_0_99':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.99, weight_decay=0) #
    case 'SGD_1e_3_0_99_5e_4':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.99, weight_decay=5e-4) #
    case 'SGD_1e_3_0_99_2e_4':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.99, weight_decay=2e-4) #
    case 'SGDPoly_1e_2_0_99':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.99, weight_decay=0) #Default
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.316227766, patience=25, threshold=0, cooldown=0, min_lr=1e-4)
        scheduler.cooldown_counter = 25
    case 'SGD_1e_1_0_9':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=0) #Default
    case 'SGD_3e_2_0_9':
        optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, momentum=0.9, weight_decay=0) #
    case 'SGD_1e_2_0_9':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0) #
    case 'SGDDiff_1e_4_0_99':
        encoder_params = set(model.encoderblocks.parameters())
        not_encoder_params = list(set(model.parameters()) - encoder_params)
        encoder_params = list(encoder_params)
        optimizer = torch.optim.SGD(
            [
                {"params": not_encoder_params},
                {"params": encoder_params, "lr": 1e-5}
            ],
            lr=1e-4, momentum=0.99, weight_decay=0
        )
    case 'SGDExp_1e_2_0_99':
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.99, weight_decay=0) #Default
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,  gamma=0.977237221, verbose=True) #0.1^(1/100)
    case 'Adam_1e_4':
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
    case 'Adam_3e_5':
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-5, betas=(0.9, 0.999))
    case _:
        assert(False)

# Early-stopping policy handed to the training loop: it watches 'val_loss', treats lower
# values as better, and is configured with patience=50 and grace_period=50.
# NOTE(review): presumably patience counts epochs without improvement and grace_period is an
# initial number of epochs during which stopping is disabled - confirm against EarlyStopper.
antilatestopper = EarlyStopper('val_loss',lower_is_better=True,patience=50,grace_period=50)

In [None]:
# Report the selected architecture, then its total and its trainable parameter counts.
print(MODELNAME)
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
print(sum(count for count, _ in param_sizes))
print(sum(count for count, trainable in param_sizes if trainable))


In [None]:
# Smoke-test the dataset wrapper on one test sample: print the dtype of its one-hot
# encoded mask (2 classes), the mask shape, and show the image/mask pair.
ig = ImageGenerator(test_img_dir, size=(HEIGHT, WIDTH))
image, mask = ig[4]
print()
one_hot_mask = torch.nn.functional.one_hot(mask.long(), 2)
print(one_hot_mask.dtype)
print(mask.shape)
display([dict(image=image, mask=mask)])

In [None]:
# One resized dataset per split, built in train / val / test order.
train_ds, val_ds, test_ds = (
    ImageGenerator(split_dir, size=(HEIGHT, WIDTH))
    for split_dir in (train_img_dir, val_img_dir, test_img_dir)
)

PREPREPROCESSOR = 'MosaicMix' #MosaicMix is better than None and PureMosaic

def _as_float():
    """First stage of every preset: ToDtype to float32 with its second argument (scale) set."""
    return v2.ToDtype(torch.float32, True)

def _passthrough():
    """RandomAffine(0), used as the 'keep the batch as it is' branch of a RandomChoice."""
    return v2.RandomAffine(0)

# Batch-level augmentation presets, keyed by name. Each entry is a factory so that only the
# selected pipeline is actually constructed.
_PRETRANSFORM_BUILDERS = {
    'None': _as_float,
    'MosaicMix': lambda: v2.Compose([_as_float(), v2.RandomChoice([_passthrough(), Mosaic()])]),
    'PureMosaic': lambda: v2.Compose([_as_float(), Mosaic()]),
    'CutMix': lambda: v2.Compose([_as_float(), v2.RandomChoice([_passthrough(), CutMixMask()])]),
    'PureCut': lambda: v2.Compose([_as_float(), CutMixMask()]),
    'MosaicCutMix': lambda: v2.Compose([_as_float(), v2.RandomChoice([Mosaic(), CutMixMask()])]),
}

# An unknown preset name is a configuration error.
assert PREPREPROCESSOR in _PRETRANSFORM_BUILDERS
pretransform = _PRETRANSFORM_BUILDERS[PREPREPROCESSOR]()
    
def collate_fn(batch):
    """Collate samples with the default collate, then run `pretransform` on the batch.

    The collated batch is unpacked into positional arguments, so `pretransform`
    receives the batched fields (image batch, mask batch) separately.
    """
    collated = torch.utils.data.default_collate(batch)
    return pretransform(*collated)

# Training loader: always shuffled with pinned memory. Every dataset except "FishV3"
# additionally gets the batch-level pretransform (via collate_fn) and drops the last,
# incomplete batch.
train_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
if DATASETNAME != "FishV3":
    train_loader_kwargs.update(collate_fn=collate_fn, drop_last=True)
train_dataloader = torch.utils.data.DataLoader(train_ds, **train_loader_kwargs)

# Validation / test loaders keep dataset order and use the default collate.
val_dataloader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, pin_memory=True)
test_dataloader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

In [None]:
# Augmentation pipeline passed to Manager as its `preprocesser`. The steps are assembled in
# application order; RandomAffine(0) is the "leave unchanged" option inside each RandomChoice.
augmentation_steps = [AddGaussianNoise(0, 0.0125)]
# Optional blur: none, 5-px kernel or 9-px kernel.
augmentation_steps.append(v2.RandomChoice([
    v2.RandomAffine(0),
    v2.GaussianBlur(5),
    v2.GaussianBlur(9),
]))
augmentation_steps.append(v2.RandomHorizontalFlip(p=0.5))
augmentation_steps.append(v2.RandomVerticalFlip(p=0.5))
augmentation_steps.append(v2.ColorJitter(brightness=0.25, contrast=0.0, saturation=0.25))
# Optional elastic deformation: none, mild or strong.
augmentation_steps.append(v2.RandomChoice([
    v2.RandomAffine(0),
    v2.ElasticTransform(alpha=768, sigma=14),
    v2.ElasticTransform(alpha=2048, sigma=20),
]))
augmentation_steps.append(v2.RandomRotation((-15, 15)))
transforms = v2.Compose(augmentation_steps).to(DEVICE)

# Loss: Focal Tversky with a channel-wise softmax applied to the network output.
# LOSSNAME is the label used for this loss in run names, plots and logs.
LOSSNAME = 'FocalTversky'
softmax_over_classes = torch.nn.Softmax(dim=1)
loss_fn = FocalTversky_loss({'apply_nonlin': softmax_over_classes}) #gamma from 0.333... to 1, default is 0.75

# Bundle data, model, loss, optimizer, augmentation and (optional) scheduler for training.
model_manager = Manager(
    train_dataloader,
    val_dataloader,
    model,
    loss_fn,
    optimizer=optimizer,
    preprocesser=transforms,
    scheduler=scheduler,
    multi_losses=AUXLOSSES,
)

In [None]:
# Clear any history left over from an earlier run first: if train() does not return
# (e.g. it is interrupted), model_history stays None and the next cell can detect that.
model_history = None
# Reset the early stopper so state from a previous training run is not carried over.
antilatestopper.reset()
# Train for at most EPOCHS epochs under the early-stopping policy.
# NOTE(review): restore_best_weights=True presumably reloads the best checkpoint when
# training ends - confirm against Manager.train.
model_history = model_manager.train(EPOCHS, antilatestopper, restore_best_weights = True)

In [None]:
# model_history is still None when the training cell did not finish; in that case fall
# back to the history object kept on the manager. Identity comparison is the correct
# (and PEP 8 recommended) way to test for None.
if model_history is None:
    print("Missing history, restoring...")
    model_history = model_manager.history

# Results

In [None]:
# Replace the training-time loaders with plain ones (no shuffling, no custom collate_fn,
# no dropped batch) so the final metrics cover every sample exactly once.
del train_dataloader
train_dataloader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
del val_dataloader
val_dataloader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)

# Evaluate in train / val / test order.
train_metric, val_metric, test_metric = (
    model_manager.evaluate(loader)
    for loader in (train_dataloader, val_dataloader, test_dataloader)
)

for split_label, split_metric in (
    ("Training", train_metric),
    ("Validating", val_metric),
    ("Testing", test_metric),
):
    print(f"{split_label} Error: mIoU: {(100*split_metric['miou']):>.3f}%, F1: {(100*split_metric['f1']):>.3f}%, Avg loss: {split_metric['loss']:>8f}")

In [None]:
# Assemble the run name used as the file-name stem for figures, the saved model and logs.
# mIoU of train/val/test as percentages, with '.' swapped for ',' in the result.
best_metrics = f"{train_metric['miou']:.2%}_{val_metric['miou']:.2%}_{test_metric['miou']:.2%}".replace('.', ',')

# "<W>^2" for square inputs, "<W>×<H>" otherwise.
if WIDTH == HEIGHT:
    size = f"{WIDTH}^2"
else:
    size = f"{WIDTH}×{HEIGHT}"

# Reduced-precision models get a suffix derived from the dtype of the first parameter.
param_dtype = next(model.parameters()).dtype
if param_dtype == torch.float16:
    precision_name = "-half"
elif param_dtype == torch.bfloat16:
    precision_name = "-bhalf"
else:
    precision_name = ''

NAME = (
    f"{MODELNAME}{precision_name}-{size}-{DATASETNAME}+{PREPREPROCESSOR}"
    f"-{OPTIMIZERNAME}-{LOSSNAME}-BS{BATCH_SIZE}"
    f"-E{model_history['best_epoch']}-{best_metrics}"
)
print(NAME)

In [None]:
# For each split, predict every sample and save one figure of image / ground-truth mask /
# predicted mask rows. The prediction is the arg-max over dimension 0 of the model output.
for split_dir, figure_suffix in (
    (train_img_dir, '-TrainResult'),
    (val_img_dir, '-ValResult'),
    (test_img_dir, '-TestResult'),
):
    ig = ImageGenerator(split_dir, size=(HEIGHT, WIDTH))
    display_list = []
    for i in range(len(ig)):
        image, mask = ig[i]
        prediction = model_manager.predict(image).detach().argmax(0).cpu()
        display_list.append({'image': image, 'mask': mask, 'pred': prediction})
    display(display_list, NAME+figure_suffix)

In [None]:
# Per-epoch curves recorded during training; these names are reused by the logging cell.
train_loss = model_history['train_loss']
val_loss = model_history['val_loss']
train_miou = model_history['train_miou']
val_miou = model_history['val_miou']
train_f1 = model_history['train_f1']
val_f1 = model_history['val_f1']
best_epoch = model_history['best_epoch']

# The y-range of the loss panel depends on the exact loss class in use.
loss_type = type(loss_fn)
if loss_type is DC_and_CE_loss:
    loss_ylim = [-1, 1]
elif loss_type is SoftDiceLoss or (loss_type is TverskyLoss and loss_type is not FocalTversky_loss):
    loss_ylim = [-1, 0]
else:
    loss_ylim = [0, 1]

# One row of three panels: loss, mIoU, F1. Each shows the training (red) and validation
# (blue) curve plus a vertical line at the best epoch.
plt.figure(figsize=(21, 7))
panels = (
    (train_loss, val_loss, LOSSNAME+' Loss', 'Loss Value', loss_ylim),
    (train_miou, val_miou, 'Mean Intersection over Union', 'mIoU', [0, 1]),
    (train_f1, val_f1, 'Mean F1 Score', 'F1', [0, 1]),
)
for panel_no, (train_curve, val_curve, panel_title, y_label, y_limits) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_no)
    plt.plot(range(len(train_curve)), train_curve, '-r', label='Training')
    plt.plot(range(len(val_curve)), val_curve, '-b', label='Validation')
    plt.axvline(x=best_epoch)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.ylim(y_limits)
    plt.legend()
plt.savefig(NAME+'-TrainPlot')
plt.show()

In [None]:
# Persist the trained network only when SAVEMODEL is set. The module object itself (not
# just a state_dict) is handed to torch.save, written to "<NAME>.pth".
# NOTE(review): loading such a file later presumably needs the model classes importable
# under the same names - confirm before moving or renaming them.
if SAVEMODEL:
    torch.save(model, NAME+'.pth')

In [None]:
# Append a human-readable record of this run: name, loss, optimizer and augmentation
# pipeline, followed by a separator line and a blank line.
log_lines = [
    'Name: ' + NAME,
    str(loss_fn),
    str(optimizer),
    str(transforms),
    '-------------------------------',
    '',
]
with open("Log.txt", "a") as f:
    f.write('\n'.join(log_lines) + '\n')

# Append one tab-separated summary row for this run.
table_fields = [
    MODELNAME, str(WIDTH), str(HEIGHT), DATASETNAME, PREPREPROCESSOR, OPTIMIZERNAME,
    LOSSNAME, str(BATCH_SIZE), str(model_history['best_epoch']),
    str(train_metric['miou']), str(val_metric['miou']), str(test_metric['miou']),
]
with open("Table.txt", "a") as f:
    f.write('\t'.join(table_fields) + '\n')

# Append each per-epoch curve as one line to its own file. Every value, including the
# last one, is followed by a tab.
curve_files = (
    ("TrainLosses.txt", train_loss),
    ("ValLosses.txt", val_loss),
    ("TrainMIoU.txt", train_miou),
    ("ValMIoU.txt", val_miou),
    ("TrainF1.txt", train_f1),
    ("ValF1.txt", val_f1),
)
for curve_path, curve_values in curve_files:
    with open(curve_path, "a") as f:
        f.write(''.join(str(value) + '\t' for value in curve_values) + '\n')

In [None]:
# Overlay the validation-loss curves of previously logged runs: line k of ValLosses.txt is
# run k, drawn with the matplotlib format string colormap[k]. A None entry, or a run
# beyond the end of the list, is not drawn.
colormap = ['-r']*8+[None]*8+['-g']*8+['-b']*8

plt.figure(figsize=(18, 18))
# Open with default (universal) newline handling so the file parses regardless of whether
# it was written with '\n' or '\r\n' line endings, or lacks a final line terminator.
with open("ValLosses.txt", "r") as f:
    for index, linestr in enumerate(f):
        line_style = colormap[index] if index < len(colormap) else None
        if line_style is None:
            continue
        # Values are tab-separated with a trailing tab; skip the empty/newline remainder.
        data = [float(token) for token in linestr.split("\t") if token.strip()]
        plt.plot(range(len(data)), data, line_style, label='XXX')
plt.savefig('TrainPlot-seg')
plt.show()