In [None]:
import numpy as np
from matplotlib import pyplot as plt
import random
import torch
import torchvision
from torchvision.transforms import v2
import copy
# Use the first CUDA device when PyTorch can see one; otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(dev)
# Record the PyTorch version used for this run.
print(torch.__version__)

In [None]:
class ImageGenerator(torch.utils.data.Dataset):
    """Map-style dataset over a directory that holds one sub-directory of images per class.

    Each sub-directory of ``imgdir`` is one class. Its integer label is the position of
    the directory name in the alphabetically sorted list of sub-directories.
    """
    def __init__(self, imgdir, size=None):
        """Index every file found in the class sub-directories of ``imgdir``.

          Args:
            imgdir: Root directory containing one sub-directory per class.
            size: Optional target size passed to ``torchvision.transforms.Resize``
              (bicubic, no antialiasing). When None, images are returned unresized.
        """
        import os
        self.imgdir = imgdir
        # Sort so the class -> label mapping is reproducible (os.listdir order is arbitrary),
        # and skip stray files in the root, which would otherwise break the nested listing.
        classnames = sorted(
            name for name in os.listdir(self.imgdir)
            if os.path.isdir(os.path.join(self.imgdir, name))
        )
        labeldict = {name: index for index, name in enumerate(classnames)}
        self.imagefns = []  # paths relative to imgdir, one per sample
        self.labels = []    # integer label of the sample at the same position
        for classname in classnames:
            for fn in sorted(os.listdir(os.path.join(self.imgdir, classname))):
                # os.path.join instead of a hard-coded '\\' so this also works outside Windows.
                self.imagefns.append(os.path.join(classname, fn))
                self.labels.append(labeldict[classname])
        assert(len(self.imagefns)==len(self.labels))
        self.resizer = None
        if size is not None:
            self.resizer = torchvision.transforms.Resize(size,torchvision.transforms.InterpolationMode.BICUBIC,antialias=False)
    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.imagefns)
    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is whatever ``torchvision.io.read_image`` returns in RGB mode,
        resized first when a ``size`` was given to the constructor.
        """
        import os
        filename = os.path.join(self.imgdir, self.imagefns[idx])
        label = self.labels[idx]
        image = torchvision.io.read_image(filename,torchvision.io.ImageReadMode.RGB)
        if self.resizer is not None:
            image = self.resizer(image)
        return image, label
def display(display, label, save=None):
    """Draw one image in a 15x15 inch figure, titled with ``label``.

    Args:
        display: Tensor to show. A 3-dimensional tensor is treated as channels-first
            and permuted to channels-last for matplotlib; anything else is drawn as is.
        label: Text used as the plot title.
        save: Optional path; when given, the figure is written there before it is shown.
    """
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 1, 1)
    plt.title(label)
    image = display
    if image.dim() == 3:
        # PyTorch stores (C, H, W); matplotlib expects (H, W, C).
        image = image.permute(1, 2, 0)
    plt.imshow(image)
    plt.axis('off')
    if save is not None:
        # Saving has to happen before show().
        plt.savefig(save)
    plt.show()

# Networks

In [None]:
def ConvBlock(num,channels,filters,addpool=False,reversed=False,dropout=0):
    """Build a ``torch.nn.Sequential`` of 3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        num: Number of conv stages.
        channels: Input channel count. In the non-reversed layout ``None`` makes the
            first convolution a ``LazyConv2d``.
        filters: Output channel count of the block.
        addpool: Prepend a 2x2, stride-2 max-pool.
        reversed: When False the first stage maps channels -> filters and the rest keep
            ``filters``. When True the first ``num-1`` stages keep ``channels`` and only
            the last stage maps channels -> filters.
        dropout: If > 0, append a ``Dropout2d`` with this probability.
    """
    # Work out the (in, out) channel pair of every conv stage up front.
    if reversed:
        plan = [(channels, channels)] * (num - 1) + [(channels, filters)]
    else:
        plan = [(channels if index == 0 else filters, filters) for index in range(num)]

    layers = [torch.nn.MaxPool2d(2, stride=2)] if addpool else []
    for cin, cout in plan:
        # Convolutions carry no bias: the BatchNorm that follows makes it redundant.
        if cin is None and not reversed:
            conv = torch.nn.LazyConv2d(cout,kernel_size=(3,3),padding='same',bias=False)
        else:
            conv = torch.nn.Conv2d(cin,cout,kernel_size=(3,3),padding='same',bias=False)
        layers += [conv, torch.nn.BatchNorm2d(cout), torch.nn.ReLU()]
    if dropout > 0:
        layers.append(torch.nn.Dropout2d(dropout))
    return torch.nn.Sequential(*layers)
class BottleResBlock(torch.nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 (optionally strided/dilated) -> 1x1 expand, added to a shortcut.

    Args:
        channels: Input channel count.
        bottlechan: Channel count inside the bottleneck.
        filters: Output channel count.
        stride: Stride of the 3x3 convolution (and of the projection shortcut).
        dilation: Dilation of the 3x3 convolution.

    No activation is applied after the residual addition.
    """
    def __init__(self,channels,bottlechan,filters,stride=1,dilation=1):
        super().__init__()
        # All convolutions are bias-free because each one feeds a BatchNorm.
        squeeze = [
            torch.nn.Conv2d(channels,bottlechan,kernel_size=(1,1),padding='valid',bias=False),
            torch.nn.BatchNorm2d(bottlechan),
            torch.nn.ReLU(),
        ]
        # For a 3x3 kernel, padding equal to the dilation keeps the spatial size at stride 1.
        spatial = [
            torch.nn.Conv2d(bottlechan,bottlechan,kernel_size=(3,3),padding=dilation,stride=stride,dilation=dilation,bias=False),
            torch.nn.BatchNorm2d(bottlechan),
            torch.nn.ReLU(),
        ]
        expand = [
            torch.nn.Conv2d(bottlechan,filters,kernel_size=(1,1),padding='valid',bias=False),
            torch.nn.BatchNorm2d(filters),
        ]
        self.bottleneck = torch.nn.Sequential(*squeeze, *spatial, *expand)
        if channels == filters and stride == 1:
            # Shapes already match: plain identity shortcut.
            self.skip = torch.nn.Identity()
        else:
            # Projection shortcut so the residual matches the branch in channels and resolution.
            self.skip = torch.nn.Sequential(
                torch.nn.Conv2d(channels,filters,kernel_size=(1,1),padding='valid',stride=stride,bias=False),
                torch.nn.BatchNorm2d(filters),
            )
    def forward(self,x):
        """Return shortcut(x) + bottleneck(x)."""
        return self.skip(x) + self.bottleneck(x)
class InvertedBottleResBlock(torch.nn.Module):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project (no final ReLU).

    Args:
        channels: Input channel count.
        expansion: Multiplier giving the hidden width ``channels*expansion``.
        filters: Output channel count.
        stride: Stride of the depthwise convolution.

    The input is added back to the output only when ``stride == 1`` and
    ``channels == filters``.
    """
    def __init__(self,channels,expansion,filters,stride=1):
        super().__init__()
        hidden = channels*expansion
        self.bottleneck = torch.nn.Sequential(
            # Pointwise expansion (bias-free: BatchNorm follows every conv here).
            torch.nn.Conv2d(channels,hidden,kernel_size=(1,1),padding='valid',bias=False),
            torch.nn.BatchNorm2d(hidden),
            torch.nn.ReLU(),
            # Depthwise 3x3: one group per channel.
            torch.nn.Conv2d(hidden,hidden,kernel_size=(3,3),padding=1,stride=stride,groups=hidden,bias=False),
            torch.nn.BatchNorm2d(hidden),
            torch.nn.ReLU(),
            # Pointwise projection, deliberately left without an activation.
            torch.nn.Conv2d(hidden,filters,kernel_size=(1,1),padding='valid',bias=False),
            torch.nn.BatchNorm2d(filters)
        )
        # Boolean flag (not a module): whether the residual addition is shape-compatible.
        self.skip = stride == 1 and channels == filters
    def forward(self,x):
        """Run the bottleneck and add the input back when the residual is enabled."""
        out = self.bottleneck(x)
        return torch.add(out,x) if self.skip else out
class PPM(torch.nn.Module):
    """Pyramid pooling module.

    For every entry ``s`` of ``pools`` the input is average-pooled with kernel and
    stride ``s`` (skipped for ``s == 1``), reduced to ``in_channels//4`` channels by a
    1x1 Conv -> BatchNorm -> ReLU, and resized back to the input resolution. The branch
    outputs are concatenated with the input itself (input last) along the channel axis,
    giving ``in_channels + len(pools)*(in_channels//4)`` output channels.

    Args:
        in_channels: Channel count of the input feature map.
        pools: Pooling sizes, one branch per entry.
    """
    def __init__(self,in_channels,pools=[1,2,4,8]):
        super().__init__()
        def poollayer(size):
            # One pyramid branch; resizing back to the input size happens in forward().
            block = torch.nn.Sequential()
            if size != 1:
                block.append(torch.nn.AvgPool2d(size, stride=size))
            block.append(torch.nn.Conv2d(in_channels,in_channels//4,kernel_size=(1,1),bias=False))
            block.append(torch.nn.BatchNorm2d(in_channels//4))
            block.append(torch.nn.ReLU())
            return block
        self.pools = torch.nn.ModuleList([poollayer(x) for x in pools])
    def forward(self, x):
        """Return the branch outputs followed by ``x``, concatenated on dim 1."""
        spatial = x.shape[2:]
        branches = []
        for pool in self.pools:
            pooled = pool(x)
            if pooled.shape[2:] != spatial:
                # Resize to the exact input size rather than by a fixed factor, so inputs
                # whose height/width are not multiples of the pool size still concatenate.
                pooled = torch.nn.functional.interpolate(pooled, size=spatial, mode='bilinear', align_corners=False)
            branches.append(pooled)
        branches.append(x)
        return torch.cat(branches,1)
class ASPP(torch.nn.Module):
    """Atrous spatial pyramid pooling.

    Runs, in parallel on the same input: one 3x3 dilated conv branch per entry of
    ``dilations``, a 1x1 conv branch, and a global-average-pool branch that is resized
    back to the input size. The branches are concatenated (dilated, 1x1, global) and
    fused to ``out_channel_per_branch`` channels by a final 1x1 conv.

    Args:
        in_channels: Channel count of the input feature map.
        out_channel_per_branch: Channel count of every branch and of the output.
        dilations: Dilation rates of the 3x3 branches.
    """
    def __init__(self,in_channels,out_channel_per_branch,dilations=[6,12,18]):
        super().__init__()
        width = out_channel_per_branch
        def unit(cin, kernel_size, **conv_kwargs):
            # Bias-free Conv -> BatchNorm -> ReLU, as a flat list of layers.
            return [
                torch.nn.Conv2d(cin,width,kernel_size=kernel_size,bias=False,**conv_kwargs),
                torch.nn.BatchNorm2d(width),
                torch.nn.ReLU(),
            ]
        self.conv1 = torch.nn.Sequential(*unit(in_channels,(1,1)))
        # padding == dilation is required for a 3x3 kernel to keep the spatial size.
        self.dilations = torch.nn.ModuleList([
            torch.nn.Sequential(*unit(in_channels,(3,3),padding=rate,dilation=rate))
            for rate in dilations
        ])
        self.globalpool = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1), *unit(in_channels,(1,1)))
        # Two fixed branches (1x1 conv, global pool) plus one per dilation rate.
        branch_count = 2+len(self.dilations)
        self.finalconv = torch.nn.Sequential(*unit(width*branch_count,(1,1)))
    def forward(self, x):
        """Return the fused multi-branch features at the input's spatial size."""
        height_width = x.size()[2:]
        branches = [branch(x) for branch in self.dilations]
        branches.append(self.conv1(x))
        pooled = self.globalpool(x)
        # The pooled branch is 1x1 spatially; stretch it back over the whole map.
        branches.append(torch.nn.functional.interpolate(pooled, height_width, mode='bilinear', align_corners=True))
        return self.finalconv(torch.cat(branches,1))
class ResNet(torch.nn.Module):
    """Bottleneck ResNet classifier: strided 7x7 stem, four bottleneck stages, global average pool, linear head.

    Args:
        in_channels: Channel count of the input images.
        classes: Number of outputs of the final linear layer.
        layers: Number of ``BottleResBlock`` units in each of the four stages.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3]):
        super().__init__()
        def make_stage(depth,channels,bottlechan,filters,stride=1,addpool=False):
            # Optional 3x3 stride-2 max-pool, then `depth` bottleneck blocks. Only the first
            # block changes the channel count and carries the stride.
            units = [torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)] if addpool else []
            for index in range(depth):
                first = index == 0
                units.append(BottleResBlock(channels if first else filters,bottlechan,filters,stride if first else 1))
            return torch.nn.Sequential(*units)
        stem = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,64,kernel_size=(7,7),stride=2,padding=3,bias=False),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        # (input channels, bottleneck channels, output channels, stride, prepend max-pool)
        stage_specs = [
            (64,64,256,1,True),
            (256,128,512,2,False),
            (512,256,1024,2,False),
            (1024,512,2048,2,False),
        ]
        self.blocks = torch.nn.ModuleList(
            [stem] + [make_stage(layers[index],*spec) for index, spec in enumerate(stage_specs)]
        )
        self.pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = torch.nn.Flatten(start_dim=1, end_dim=-1)
        self.classifier = torch.nn.Linear(2048,classes)
    def forward(self, x):
        """Return the class scores produced by the linear head."""
        for block in self.blocks:
            x = block(x)
        pooled = self.flatten(self.pooling(x))
        return self.classifier(pooled)
class VGG16(torch.nn.Module):
    """VGG-16 style classifier with BatchNorm conv blocks.

    Five ``ConvBlock`` stages (2, 2, 3, 3, 3 convolutions), each followed by a 2x2
    max-pool, then a 7x7 adaptive average pool, two 4096-wide fully connected layers
    with ReLU and dropout, and a final linear classifier.

    Args:
        in_channels: Channel count of the input images.
        classes: Number of outputs of the final linear layer.
    """
    def __init__(self,in_channels,classes):
        super().__init__()
        widths = [in_channels,64,128,256,512,512]
        depths = [2,2,3,3,3]
        self.blocks = torch.nn.ModuleList(
            [ConvBlock(depths[stage],widths[stage],widths[stage+1]) for stage in range(5)]
        )
        # One pooling layer after each conv stage.
        self.pools = torch.nn.ModuleList([torch.nn.MaxPool2d(2, stride=2) for _ in range(5)])
        self.avgpool = torch.nn.AdaptiveAvgPool2d((7,7))
        self.flatten = torch.nn.Flatten(start_dim=1, end_dim=-1)
        hidden = 4096
        self.fcs = torch.nn.ModuleList([
            torch.nn.Linear(7*7*widths[5],hidden,bias=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(hidden,hidden,bias=True),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
        ])
        self.classifier = torch.nn.Linear(hidden,classes)
    def forward(self, x):
        """Return the class scores produced by the linear head."""
        for stage, block in enumerate(self.blocks):
            x = self.pools[stage](block(x))
        x = self.flatten(self.avgpool(x))
        for layer in self.fcs:
            x = layer(x)
        return self.classifier(x)
class VGG13(VGG16):
    """VGG16 variant in which each of the five conv blocks has two convolutions."""
    def __init__(self,in_channels,classes):
        super().__init__(in_channels,classes)
        widths = [in_channels,64,128,256,512,512]
        # Replace the conv blocks built by the VGG16 constructor; pools and head are kept.
        self.blocks = torch.nn.ModuleList(
            [ConvBlock(2,widths[stage],widths[stage+1]) for stage in range(5)]
        )
class VGG19(VGG16):
    """VGG16 variant whose five conv blocks have 2, 2, 4, 4 and 4 convolutions."""
    def __init__(self,in_channels,classes):
        super().__init__(in_channels,classes)
        widths = [in_channels,64,128,256,512,512]
        depths = [2,2,4,4,4]
        # Replace the conv blocks built by the VGG16 constructor; pools and head are kept.
        self.blocks = torch.nn.ModuleList(
            [ConvBlock(depths[stage],widths[stage],widths[stage+1]) for stage in range(5)]
        )
class UNet(torch.nn.Module):
    """U-Net style encoder/decoder with skip connections.

    Five encoder stages, four transposed-conv upsamplers and four decoder stages. At
    each decoder level the matching encoder output is concatenated in front of the
    upsampled features. A 1x1 conv produces the class maps.

    Args:
        in_channels: Channel count of the input images (used only when the encoder is built here).
        classes: Number of output channels of the classifier.
        convblocks: Optional ``torch.nn.ModuleList``; its first five entries are used as the
            encoder stages. Any other value builds the default ``ConvBlock`` encoder.
        encoderblocks: Convolutions per encoder stage (default encoder only).
        decoderblocks: Convolutions per decoder stage.
        upsample_output: Scale factor of a final bilinear upsample of the class maps; 1 disables it.
    """
    def __init__(self,in_channels,classes,convblocks=None,encoderblocks=[2,2,2,2,2],decoderblocks=[2,2,2,2],upsample_output=1):
        super().__init__()
        self.channels = [64,128,256,512,1024]
        widths = self.channels
        def up(transposed_conv):
            # 2x transposed convolution followed by ReLU.
            return torch.nn.Sequential(transposed_conv, torch.nn.ReLU())
        if type(convblocks) is torch.nn.ModuleList:
            encoder = [convblocks[level] for level in range(5)]
            # Lazy variant: its input channel count is inferred on the first forward pass.
            deepest_up = torch.nn.LazyConvTranspose2d(widths[3],kernel_size=(2,2),stride=(2,2))
        else:
            # First stage works at full resolution; every later stage starts with a max-pool.
            encoder = [ConvBlock(encoderblocks[0],in_channels,widths[0])]
            encoder += [ConvBlock(encoderblocks[level],widths[level-1],widths[level],True) for level in range(1,5)]
            deepest_up = torch.nn.ConvTranspose2d(widths[4],widths[3],kernel_size=(2,2),stride=(2,2))
        self.encoderblocks = torch.nn.ModuleList(encoder)
        self.upsamplers = torch.nn.ModuleList(
            [up(deepest_up)] + [
                up(torch.nn.ConvTranspose2d(widths[level],widths[level-1],kernel_size=(2,2),stride=(2,2)))
                for level in (3,2,1)
            ]
        )
        # Each decoder stage sees skip + upsampled features, i.e. twice its output width.
        self.decoderblocks = torch.nn.ModuleList(
            [ConvBlock(decoderblocks[step],widths[4-step],widths[3-step]) for step in range(4)]
        )
        self.classifier = torch.nn.Conv2d(widths[0],classes,kernel_size=(1,1),padding='same')
        if upsample_output != 1:
            self.upsampler = torch.nn.Upsample(scale_factor=upsample_output, mode='bilinear')
        else:
            self.upsampler = torch.nn.Identity()
    def forward(self, x):
        """Return the (optionally upsampled) class maps."""
        features = []
        for level in range(5):
            x = self.encoderblocks[level](x)
            features.append(x)
        x = features.pop()  # deepest encoder output
        for step in range(4):
            x = self.upsamplers[step](x)
            # Skip connection: encoder features first, upsampled features second.
            x = torch.cat((features.pop(),x),1)
            x = self.decoderblocks[step](x)
        return self.upsampler(self.classifier(x))
class UNetPPM(UNet):
    """UNet whose last decoder stage first runs a PPM over the concatenated skip/upsampled features.

    Args:
        pools: Pooling sizes passed to ``PPM``; all other arguments are those of ``UNet``.
    """
    def __init__(self,in_channels,classes,convblocks=None,encoderblocks=[2,2,2,2,2],decoderblocks=[2,2,2,2],upsample_output=1,pools=[1,2,4,8]):
        super().__init__(in_channels,classes,convblocks,encoderblocks,decoderblocks,upsample_output)
        fused_width = self.channels[1]
        # PPM concatenates one branch of fused_width//4 channels per pool size onto its input.
        ppm_width = (fused_width//4)*len(pools)+fused_width
        self.decoderblocks[3] = torch.nn.Sequential(
            PPM(fused_width,pools),
            ConvBlock(decoderblocks[3],ppm_width,self.channels[0]),
        )
class UNetASPP(UNet):
    """UNet whose last decoder stage is followed by an ASPP head of the same width.

    Args:
        dilations: Dilation rates passed to ``ASPP``; all other arguments are those of ``UNet``.
    """
    def __init__(self,in_channels,classes,convblocks=None,encoderblocks=[2,2,2,2,2],decoderblocks=[2,2,2,2],upsample_output=1,dilations=[6,12,18]):
        super().__init__(in_channels,classes,convblocks,encoderblocks,decoderblocks,upsample_output)
        out_width = self.channels[0]
        # The ASPP runs after the regular conv block, on its out_width-channel output.
        self.decoderblocks[3] = torch.nn.Sequential(
            ConvBlock(decoderblocks[3],self.channels[1],out_width),
            ASPP(out_width,out_width,dilations),
        )
class SegNet(torch.nn.Module):
    """SegNet style encoder/decoder that upsamples with the encoder's max-pool indices.

    Five encoder stages, each followed by a 2x2 max-pool that records its indices.
    The decoder mirrors this: every step un-pools with the indices of the matching
    encoder level and then applies a reversed ``ConvBlock``.

    Args:
        in_channels: Channel count of the input images (used only when the encoder is built here).
        classes: Number of output channels of the last decoder stage.
        convblocks: Optional ``torch.nn.ModuleList``; its first five entries are used as the
            encoder stages. Any other value builds the default ``ConvBlock`` encoder.
        blocks: Convolutions per stage, shared by encoder level ``i`` and its mirrored decoder level.
    """
    def __init__(self,in_channels,classes,convblocks=None,blocks=[2,2,3,3,3]):
        super().__init__()
        self.channels = [64,128,256,512,512]
        if type(convblocks) is torch.nn.ModuleList:
            self.encoderblocks = torch.nn.ModuleList([
                convblocks[0],
                convblocks[1],
                convblocks[2],
                convblocks[3],
                convblocks[4]
            ])
        else:
            self.encoderblocks = torch.nn.ModuleList([
                ConvBlock(blocks[0],in_channels,self.channels[0]),
                ConvBlock(blocks[1],self.channels[0],self.channels[1]),
                ConvBlock(blocks[2],self.channels[1],self.channels[2]),
                ConvBlock(blocks[3],self.channels[2],self.channels[3]),
                ConvBlock(blocks[4],self.channels[3],self.channels[4])
            ])
        self.downsamplers = torch.nn.ModuleList([
            torch.nn.MaxPool2d(2, stride=2,return_indices=True) for _ in range(5)
        ])
        self.upsamplers = torch.nn.ModuleList([
            torch.nn.MaxUnpool2d(kernel_size=2, stride=2) for _ in range(5)
        ])
        # Decoder stages run deepest-first, so they are listed in reverse encoder order.
        self.decoderblocks = torch.nn.ModuleList([
            ConvBlock(blocks[4],self.channels[4],self.channels[3],reversed=True),
            ConvBlock(blocks[3],self.channels[3],self.channels[2],reversed=True),
            ConvBlock(blocks[2],self.channels[2],self.channels[1],reversed=True),
            ConvBlock(blocks[1],self.channels[1],self.channels[0],reversed=True),
            ConvBlock(blocks[0],self.channels[0],classes,reversed=True)
        ])
    def forward(self, x):
        """Encode with index-recording pooling, then decode by un-pooling with those indices."""
        pool_indices = []
        pool_inputs = []
        for level in range(5):
            x = self.encoderblocks[level](x)
            # Remember the pre-pool size: for odd sizes 2x2 pooling floors, and the default
            # un-pool output (2 * pooled size) would no longer match the stored indices.
            pool_inputs.append(x.shape[2:])
            x, indices = self.downsamplers[level](x)
            pool_indices.append(indices)
        for step in range(5):
            x = self.upsamplers[step](x, indices=pool_indices.pop(), output_size=pool_inputs.pop())
            x = self.decoderblocks[step](x)
        return x
class FastSCNN(torch.nn.Module):
    """Fast-SCNN style two-branch segmentation network.

    ``downsample`` reduces the input by 8x. ``bottlenecks`` + ``ppm`` form the deep
    branch (a further 4x reduction). ``fuse1``/``fuse2`` bring both branches to 128
    channels at the 1/8 resolution, where they are added. ``last`` classifies and
    upsamples by 8x back to the input resolution.

    Args:
        in_channels: Channel count of the input images.
        classes: Number of output channels.
        backbonemodel: Unused; kept so existing call sites keep working.
    """
    def __init__(self,in_channels,classes,backbonemodel=None):
        super().__init__()
        self.downsample = torch.nn.Sequential(
            # Uses in_channels (previously hard-coded to 3) so non-RGB inputs work.
            torch.nn.Conv2d(in_channels,32,kernel_size=(3,3),padding=1,stride=2,bias=False), #Bias is pointless for BatchNorm
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32,32,kernel_size=(3,3),padding=1,groups=32,stride=2,bias=False), #Depthwise Conv
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32,48,kernel_size=(1,1),padding='same',bias=False), #Pointwise Conv
            torch.nn.BatchNorm2d(48),
            torch.nn.ReLU(),
            torch.nn.Conv2d(48,48,kernel_size=(3,3),padding=1,groups=48,stride=2,bias=False), #Depthwise Conv
            torch.nn.BatchNorm2d(48),
            torch.nn.ReLU(),
            torch.nn.Conv2d(48,64,kernel_size=(1,1),padding='same',bias=False), #Pointwise Conv
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        # Two stride-2 blocks: the deep branch ends at 1/4 of the downsample output.
        self.bottlenecks = torch.nn.Sequential(
            InvertedBottleResBlock(64,6,64,2),
            InvertedBottleResBlock(64,6,64,1),
            InvertedBottleResBlock(64,6,64,1),
            InvertedBottleResBlock(64,6,96,2),
            InvertedBottleResBlock(96,6,96,1),
            InvertedBottleResBlock(96,6,96,1),
            InvertedBottleResBlock(96,6,128,1),
            InvertedBottleResBlock(128,6,128,1),
            InvertedBottleResBlock(128,6,128,1)
        )
        self.ppm = PPM(128) #outchannel is 256
        # Shallow branch: project the 64-channel features to the fusion width.
        self.fuse1 = torch.nn.Sequential(
            torch.nn.Conv2d(64,128,kernel_size=(1,1),padding='valid',bias=False), #Bias is pointless for BatchNorm
            torch.nn.BatchNorm2d(128)
        )
        # Deep branch: upsample 4x to the shallow resolution, then depthwise + pointwise conv.
        self.fuse2 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=4, mode='bilinear'),
            torch.nn.Conv2d(256,256,kernel_size=(3,3),padding='same',groups=256,bias=False), #DWConv
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256,128,kernel_size=(1,1),padding='valid',bias=False),
            torch.nn.BatchNorm2d(128)
        )
        # Applied to the sum of the two branches.
        self.fuse = torch.nn.Sequential(
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU()
        )
        self.last = torch.nn.Sequential(
            torch.nn.Conv2d(128,128,kernel_size=(3,3),padding='same',groups=128,bias=False), #Depthwise Conv
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,kernel_size=(1,1),padding='same',bias=False), #Pointwise Conv
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,kernel_size=(3,3),padding='same',groups=128,bias=False), #Depthwise Conv
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,kernel_size=(1,1),padding='same',bias=False), #Pointwise Conv
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,classes,kernel_size=(1,1),padding='same',bias=False), #Pointwise Conv
            torch.nn.Upsample(scale_factor=8, mode='bilinear')
        )
    def forward(self,x):
        """Return class maps upsampled 8x from the fused 1/8-resolution features."""
        x = self.downsample(x)
        x2 = self.bottlenecks(x)
        x2 = self.ppm(x2)
        x1 = self.fuse1(x)
        x2 = self.fuse2(x2)
        x = torch.add(x1,x2)
        x = self.fuse(x)
        x = self.last(x)
        return x
def INF(B,H,W,device=None):
    """Return a ``(B*W, H, H)`` tensor with ``-inf`` on every diagonal and zero elsewhere.

    Added to the column-wise attention energies so a position does not attend to
    itself in that branch.

    Args:
        B: Batch size.
        H: Feature-map height.
        W: Feature-map width.
        device: Device for the result. Defaults to CUDA when available (the previous,
            CUDA-only behaviour), otherwise the CPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return -torch.diag(torch.full((H,), float("inf"), device=device),0).unsqueeze(0).repeat(B*W,1,1)
class CrissCrossAttention(torch.nn.Module):
    """Criss-Cross Attention Module, copied straight from the repo.

    Each position attends to the positions in its own column and its own row. The
    attended values are scaled by the learnable ``gamma`` (initialised to zero) and
    added to the input, so the output has the input's shape.

    Args:
        in_dim: Channel count of the input. Queries and keys use ``in_dim//8`` channels.
    """
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()

        self.query_conv = torch.nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = torch.nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = torch.nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Softmax over the concatenated column (H) and row (W) energies.
        self.softmax = torch.nn.Softmax(dim=3)
        self.INF = INF
        self.gamma = torch.nn.Parameter(torch.zeros(1))
    def forward(self, x):
        """Apply one pass of criss-cross attention to a 4-dimensional input."""
        m_batchsize, _, height, width = x.size()
        # *_H tensors treat every column as a sequence of length `height`,
        # *_W tensors treat every row as a sequence of length `width`.
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        # Build the diagonal mask on the input's device so the module also runs on the CPU
        # or on a non-default GPU.
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width, device=x.device)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        # Split the joint attention back into its column and row parts.
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        return self.gamma*(out_H + out_W) + x
class RCCAModule(torch.nn.Module):
    """Recurrent criss-cross attention head (adapted from the CCNet repository).

    A 3x3 conv reduces the input to ``in_channels//4`` channels, criss-cross attention
    is applied ``recurrence`` times with shared weights, and a second 3x3 conv follows.
    The result is concatenated with the original input and classified.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Width of the fusion conv that precedes the classifier.
        num_classes: Number of output channels.
    """
    def __init__(self, in_channels, out_channels, num_classes):
        super(RCCAModule, self).__init__()
        inter_channels = in_channels // 4
        self.conva = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            torch.nn.BatchNorm2d(inter_channels),
            torch.nn.ReLU(),
        )
        self.cca = CrissCrossAttention(inter_channels)
        self.convb = torch.nn.Sequential(
            torch.nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            torch.nn.BatchNorm2d(inter_channels),
            torch.nn.ReLU(),
        )
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels+inter_channels, out_channels, kernel_size=3, padding=1, dilation=1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout2d(0.1),
            # Consumes the out_channels produced above (previously hard-coded to 512,
            # which only worked when out_channels == 512).
            torch.nn.Conv2d(out_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        )
    def forward(self, x, recurrence=1):
        """Return class maps at the input's spatial size.

        Args:
            x: Input feature map.
            recurrence: Number of times the shared attention module is applied.
        """
        output = self.conva(x)
        for i in range(recurrence):
            output = self.cca(output)
        output = self.convb(output)
        output = self.bottleneck(torch.cat([x, output], 1))
        return output
class DilatedResNet(torch.nn.Module):
    """Bottleneck ResNet backbone with dilated last stages and a 1x1 segmentation head.

    The last two stages use dilation (2 and 4) instead of stride, so the backbone output
    is 1/8 of the input resolution. The head output is upsampled 8x bilinearly.

    Args:
        in_channels: Channel count of the input images.
        classes: Number of output channels of the head.
        layers: Number of ``BottleResBlock`` units in each of the four stages.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3]):
        super().__init__()
        def make_stage(depth,channels,bottlechan,filters,stride=1,dilation=1,addpool=False):
            # Optional 3x3 stride-2 max-pool, then `depth` bottleneck blocks. Only the first
            # block changes the channel count and carries the stride; all share the dilation.
            units = [torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)] if addpool else []
            for index in range(depth):
                first = index == 0
                units.append(BottleResBlock(channels if first else filters,bottlechan,filters,stride if first else 1,dilation=dilation))
            return torch.nn.Sequential(*units)
        stem = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,64,kernel_size=(7,7),stride=2,padding=3,bias=False),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        self.blocks = torch.nn.Sequential(
            stem,
            make_stage(layers[0],64,64,256,addpool=True),
            make_stage(layers[1],256,128,512,stride=2),
            make_stage(layers[2],512,256,1024,stride=1,dilation=2),
            make_stage(layers[3],1024,512,2048,stride=1,dilation=4),
        )
        self.head = torch.nn.Conv2d(2048,classes,kernel_size=(1,1),padding='same')
        self.upsampler = torch.nn.Upsample(scale_factor=8, mode='bilinear')
    def forward(self, x):
        """Return class maps upsampled 8x from the backbone features."""
        features = self.blocks(x)
        return self.upsampler(self.head(features))
class CCNet(DilatedResNet):
    """DilatedResNet backbone with a recurrent criss-cross attention head.

    Args:
        in_channels: Channel count of the input images.
        classes: Number of output channels.
        layers: Number of bottleneck blocks per backbone stage.
        recurrence: How many times the attention module is applied in the head.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3],recurrence=2):
        # Forward the caller's `layers`; it used to be replaced by a hard-coded [3,4,6,3].
        super().__init__(in_channels,classes,layers=layers)
        # Replace the parent's 1x1 head; the 8x bilinear upsampler is inherited unchanged.
        self.head = RCCAModule(2048, 512, classes)
        self.recur = recurrence
    def forward(self, x):
        """Return class maps upsampled 8x from the attention head's output."""
        x = self.blocks(x)
        y = self.head(x,self.recur)
        y = self.upsampler(y)
        return y
class DilatedResNetWithPPM(DilatedResNet):
    """DilatedResNet backbone followed by a pyramid pooling module and a 1x1 head.

    Args:
        in_channels: Channel count of the input images.
        classes: Number of output channels.
        layers: Number of bottleneck blocks per backbone stage.
    """
    def __init__(self,in_channels,classes,layers=[3,4,6,3]):
        # Forward the caller's `layers`; it used to be replaced by a hard-coded [3,4,6,3].
        super().__init__(in_channels,classes,layers=layers)
        self.ppm = PPM(2048)
        # PPM output = its input plus one 2048//4-channel branch per pool size.
        self.head = torch.nn.Conv2d((2048//4)*len(self.ppm.pools)+2048,classes,kernel_size=(1,1),padding='same')
    def forward(self, x):
        """Return class maps upsampled 8x from the pyramid-pooled backbone features."""
        x = self.blocks(x)
        x = self.ppm(x)
        y = self.head(x)
        y = self.upsampler(y)
        return y

# Trainer

In [None]:
def iou(pred, target, n_classes = 2, include_bg_class = True):
    """Compute the intersection-over-union of every class for every sample.

    Args:
        pred: Predicted class indices; reduced over dims 1 and 2 (assumed BATCH x H x W).
        target: Ground-truth class indices with the same layout.
        n_classes: Number of classes; class ids run from 0 to ``n_classes-1``.
        include_bg_class: When False, class 0 is skipped.

    Returns:
        A list with one numpy array per evaluated class, each holding the per-sample
        IoU. A class absent from both prediction and target gives 0/0, i.e. nan.
    """
    pred = pred.cpu()
    target = target.cpu()
    first_class = 0 if include_bg_class else 1
    ious = []
    for cls in range(first_class, n_classes):
        pred_mask = pred == cls
        target_mask = target == cls
        intersection = (pred_mask & target_mask).count_nonzero((1,2)).numpy()
        pred_area = pred_mask.count_nonzero((1,2)).numpy()
        target_area = target_mask.count_nonzero((1,2)).numpy()
        # |A or B| = |A| + |B| - |A and B|
        union = pred_area + target_area - intersection
        ious.append(intersection/union)
    return ious
class EarlyStopper:
    """Patience-based early stopping on one metric of a history dict.

    Args:
        metric_name: key of the history dict whose last entry is monitored.
        lower_is_better: True for losses, False for scores such as accuracy.
        patience: number of degraded checks after which stopping is signalled.
        delta: tolerance; a check only counts as degraded when the metric is
            worse than the best value by more than `delta`.
    """
    def __init__(self, metric_name='val_loss', lower_is_better=True, patience=4, delta=0):
        self.metric_name = metric_name
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.lower_is_better = lower_is_better
        # Start from the worst possible value so the first check always improves.
        self.best_metric = float('inf') if lower_is_better else -float('inf')
    def early_stop(self, history, lower_is_better=True):
        """Inspect the newest metric value and report what happened.

        Args:
            history: mapping that contains a non-empty sequence under
                `self.metric_name`.
            lower_is_better: ignored; kept for interface compatibility. The
                direction chosen in the constructor is the one that applies.

        Returns:
            -1 when the metric set a new best (counter reset),
             1 when the degraded-check counter reached `patience`,
             0 otherwise.
        """
        metric = history[self.metric_name][-1]
        if self.lower_is_better:
            improved = metric < self.best_metric
            degraded = metric > (self.best_metric + self.delta)
        else:
            improved = metric > self.best_metric
            # Tolerance is subtracted here: "worse" means lower than best - delta.
            degraded = metric < (self.best_metric - self.delta)
        if improved:
            self.best_metric = metric
            self.counter = 0
            return -1
        if degraded:
            self.counter += 1
            if self.counter >= self.patience:
                return 1
        # Within tolerance of the best value, or patience not yet exhausted.
        return 0

In [None]:
def ohe(y, num_classes=7):
    """One-hot encode integer class labels as a float tensor.

    Args:
        y: integer (int64) tensor of class indices.
        num_classes: length of the one-hot axis; defaults to the 7 classes
            used throughout this notebook.

    Returns:
        Float tensor of shape y.shape + (num_classes,).
    """
    return torch.nn.functional.one_hot(y, num_classes).float()
class Manager():
    """Drives training, validation and evaluation of a classification model.

    The constructor stores the loaders, model, loss, optimizer and the
    optional augmentation pipeline / LR scheduler. The device inputs are moved
    to and the dtype they are converted to are both taken from the model's
    first parameter.

    Attributes:
        history: dict with per-epoch lists 'train_loss', 'train_acc',
            'val_loss', 'val_acc' plus the zero-indexed 'best_epoch'.
        pbrecord: deep copy of the model state dict at the best epoch so far,
            or None when no best epoch has been recorded.
        stopnow: set by validate(); True when the early stopper asked to stop.
        personalbest: set by validate(); True when the last epoch was a new best.
    """
    def __init__(self, train_dl, val_dl, model, loss_fn, optimizer, preprocesser = None, scheduler = None):
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.preprocesser = preprocesser  # applied to training inputs only
        self.scheduler = scheduler
        parametertensor = next(self.model.parameters())
        self.to_device = parametertensor.device
        # Converts inputs to the model's dtype; the second argument (True) enables value scaling.
        self.rescaler = torchvision.transforms.v2.ToDtype(parametertensor.dtype,True)
        self.stopnow = False
        self.personalbest = False
        self.pbrecord = None
        self.history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'best_epoch': 0}
    def train_one_epoch(self):
        """Run one optimisation pass over `train_dl` and append loss/accuracy to history.

        Training targets are compared through `y.argmax(1)`, i.e. they are
        expected as per-class scores (one-hot or mixed labels), not indices.
        """
        size = len(self.train_dl.dataset)
        count, count_true = 0, 0
        batchsizes, losses = [], []
        self.model.train()
        for X, y in self.train_dl:
            X = self.rescaler(X).to(device=self.to_device,memory_format=torch.channels_last)
            y = y.to(device=self.to_device)
            if self.preprocesser is not None:
                X = self.preprocesser(X)
            # Compute prediction and loss
            pred = self.model(X)
            loss = self.loss_fn(pred, y)
            # Backpropagation
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Keep only the Python float so the graph can be freed
            loss = loss.item()
            # Count samples whose top prediction matches the dominant label
            count_true += (pred.argmax(1)==y.argmax(1)).count_nonzero().item()
            batchsizes.append(len(X))
            losses.append(loss)
            count += len(X)
            print(f"loss: {loss:>7f}  [{count:>5d}/{size:>5d}]",end = '\r')
        print()
        # Weight each batch loss by its size so a short last batch is not over-counted
        meanloss = np.average(np.array(losses),weights=np.array(batchsizes))
        accuracy = count_true/count
        print(f"Training Error:   Accuracy: {(100*accuracy):>.02f}%, Avg loss: {meanloss:>8f}")
        self.history['train_loss'].append(meanloss)
        self.history['train_acc'].append(accuracy)
    def _score(self, dataloader):
        """Return (batch-size-weighted mean loss, accuracy) of the model on `dataloader`.

        Runs in eval mode without gradients; labels are cast to int64 class
        indices and no augmentation is applied.
        """
        self.model.eval()
        count, count_true = 0, 0
        batchsizes, losses = [], []
        with torch.no_grad():
            for X, y in dataloader:
                X = self.rescaler(X).to(device=self.to_device,memory_format=torch.channels_last)
                y = y.to(device=self.to_device,dtype=torch.int64)
                pred = self.model(X)
                losses.append(self.loss_fn(pred, y).item())
                count_true += (pred.argmax(1)==y).count_nonzero().item()
                batchsizes.append(len(X))
                count += len(X)
        meanloss = np.average(np.array(losses),weights=np.array(batchsizes))
        return meanloss, count_true/count
    def validate(self, earlystop):
        """Score `val_dl`, record the result and update the stop / best flags.

        Args:
            earlystop: object whose `early_stop(history)` returns 1 to stop,
                -1 for a new best and anything else to continue.
        """
        meanloss, accuracy = self._score(self.val_dl)
        print(f"Validating Error: Accuracy: {(100*accuracy):>.02f}%, Avg loss: {meanloss:>8f}")
        self.history['val_loss'].append(meanloss)
        self.history['val_acc'].append(accuracy)
        earlystopout = earlystop.early_stop(self.history)
        self.stopnow = earlystopout == 1
        self.personalbest = earlystopout == -1
    def evaluate(self, dataloader):
        """Score an arbitrary dataloader; returns {'loss': mean loss, 'acc': accuracy}."""
        meanloss, accuracy = self._score(dataloader)
        return {'loss': meanloss, 'acc': accuracy}
    def predict(self, X):
        """Return the raw model output for a single un-batched input `X`.

        A batch dimension is added in front before the forward pass, so the
        result keeps a leading dimension of size 1.
        """
        self.model.eval()
        with torch.no_grad():
            X = self.rescaler(X.unsqueeze(0)).to(device=self.to_device,memory_format=torch.channels_last)
            pred = self.model(X)
        return pred.detach()
    def train(self, epochs, earlystop, restore_best_weights = True):
        """Train until `epochs` epochs have run or the early stopper fires.

        Args:
            epochs: number of epochs to run in this call.
            earlystop: early-stopping object, see validate().
            restore_best_weights: when True, snapshot the weights at every new
                best epoch and load the snapshot back once training ends.

        Returns:
            The history dict (shared with `self.history`).
        """
        epoch = 0
        while epoch != epochs:
            if self.scheduler is not None:
                print(f"Epoch {epoch+1} ({self.scheduler.get_last_lr()})\n-------------------------------")
            else:
                print(f"Epoch {epoch+1}\n-------------------------------")
            self.train_one_epoch()
            self.validate(earlystop)
            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    # Always feed the newest validation loss, even when train()
                    # is called again on a history that already has entries.
                    self.scheduler.step(self.history['val_loss'][-1])
                else:
                    self.scheduler.step()
            if self.personalbest:
                print('New personal best!')
                if restore_best_weights:
                    # Snapshot this manager's own model, not a notebook global.
                    self.pbrecord = copy.deepcopy(self.model.state_dict())
                # Zero-indexed position of the best epoch inside the history lists.
                self.history['best_epoch'] = len(self.history['val_loss']) - 1
            if self.stopnow:
                print('Stop.')
                print()
                break
            print()
            epoch += 1
        print('Training ended.')
        if restore_best_weights:
            if self.pbrecord is None:
                # No epoch ever registered as best (e.g. epochs == 0); nothing to load.
                print('No best weights were recorded; keeping current weights.')
            else:
                print('Restoring best weights...')
                self.model.load_state_dict(self.pbrecord)
        return self.history
def dice_loss(pred, target):
    """Smoothed Dice loss computed over the whole batch at once.

    Both inputs are flattened to (batch, -1) and cast to float; the sums run
    over every element of the batch, so a single scalar is returned rather
    than one value per sample.

    Args:
        pred: prediction tensor whose first dimension is the batch.
        target: tensor with the same number of elements per sample as `pred`.

    Returns:
        Scalar tensor 1 - (2*sum(pred*target) + 1) / (sum(pred) + sum(target) + 1).
    """
    batch = pred.size(0)
    flat_pred = pred.contiguous().view(batch, -1).float()
    flat_target = target.contiguous().view(batch, -1).float()
    overlap = torch.sum(flat_pred * flat_target).float()
    eps = 1.  # smoothing term: keeps the ratio defined when both inputs are all zero
    score = (2. * overlap + eps) / (flat_pred.sum() + flat_target.sum() + eps)
    return 1-score

# Config

In [None]:
def init_lazies(model):
    """Materialise lazy modules by pushing one random dummy batch through `model`.

    The dummy input is a (1, 3, HEIGHT, WIDTH) tensor moved to the notebook's
    DEVICE / DTYPE. The forward pass runs without gradients and the model is
    left in eval mode afterwards.
    """
    with torch.no_grad():
        model.eval()
        dummy_batch = torch.rand(1, 3, HEIGHT, WIDTH)
        dummy_batch = dummy_batch.to(device=DEVICE,dtype=DTYPE)
        model(dummy_batch)
def init_weights(m):
    """Apply He (Kaiming normal) initialisation to a single linear / conv module.

    Intended for `module.apply(init_weights)` or for a direct call on one
    layer. Only modules whose exact type is one of the (lazy) Linear, Conv2d
    or ConvTranspose2d classes are touched; every call prints a short trace.
    Frozen weights (requires_grad False) and 1-D weights are left unchanged.
    An uninitialised lazy weight trips an assertion.
    """
    print(type(m))
    supported = (
        torch.nn.Linear, torch.nn.LazyLinear,
        torch.nn.Conv2d, torch.nn.LazyConv2d,
        torch.nn.ConvTranspose2d, torch.nn.LazyConvTranspose2d,
    )
    if type(m) not in supported:
        return
    if type(m.weight) is torch.nn.UninitializedParameter:
        print("Skipping uninitialized weights....")
        assert(False)
    bias_shape = "" if m.bias is None else m.bias.shape
    print("Shape: "+str(m.weight.shape)+" + "+str(bias_shape))
    if not m.weight.requires_grad:
        print("Weights frozen, skipping....")
        return
    if len(m.weight.shape) <= 1:
        print("He initializer not applicable for 1 dimensional tensor.")
        return
    torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)
    print("Applied He initializer!")
class AddGaussianNoise:
    """Callable transform that adds Gaussian noise to a tensor.

    Args:
        mean: constant offset added after the scaled noise.
        std: factor the standard-normal sample is multiplied by.
    """
    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std
    def __call__(self, tensor):
        # Noise is sampled with the input's shape, directly on the input's device.
        scaled_noise = torch.randn(tensor.size(),device=tensor.device) * self.std
        return tensor + scaled_noise + self.mean
    def __repr__(self):
        return type(self).__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class MakeOHE:
    """Two-input transform: passes the first tensor through and one-hot encodes the second.

    Stateless; the encoding itself is delegated to `ohe`.
    """
    def __call__(self, tensor, tensor2):
        encoded = ohe(tensor2)
        return (tensor, encoded)
    def __repr__(self):
        return type(self).__name__

# Main

In [None]:
# Dataset identity: DATASETNAME is only a label that ends up in the run name.
DATASETNAME = 'Fishv2_CLS_DE_AA'
# NOTE(review): absolute, machine-specific Windows path -- adjust per machine.
dataset_dir = r'C:\Users\User\Documents\AAAAA\eseg2\SB-FishDisease-P20-20-EP20-splitAA'
# One sub-folder per split, joined with a backslash like the rest of the notebook.
train_img_dir, val_img_dir, test_img_dir = (
    dataset_dir+split for split in (r'\train', r'\valid', r'\test')
)

# Show which entry points the pytorch/vision hub repository offers.
available_entrypoints = torch.hub.list('pytorch/vision')
print(available_entrypoints)

In [None]:
# --- Run configuration ---
DTYPE = torch.float32
EPOCHS = 2000  # upper bound handed to Manager.train; early stopping may end sooner
RBW = True     # passed to Manager.train as restore_best_weights

BATCH_SIZE = 8
HEIGHT, WIDTH = 448, 448

scheduler = None  # stays None unless the selected optimizer recipe creates one

# Size of the replacement classification head.
# NOTE: other cells of this notebook hard-code the same value 7.
NUM_CLASSES = 7

# Model selection: MODELNAME -> (torch.hub entry point, pretrained weights or None).
# Names starting with PRE_ load ImageNet weights; the others train from scratch.
MODEL_SPECS = {
    'VGG13':             ('vgg13_bn',  None),
    'PRE_VGG13':         ('vgg13_bn',  'VGG13_BN_Weights.IMAGENET1K_V1'),
    'VGG16':             ('vgg16_bn',  None),
    'PRE_VGG16':         ('vgg16_bn',  'VGG16_BN_Weights.IMAGENET1K_V1'),
    'PRE_ResNet18-test': ('resnet18',  'ResNet18_Weights.IMAGENET1K_V1'),
    'PRE_ResNet34':      ('resnet34',  'ResNet34_Weights.IMAGENET1K_V1'),
    'ResNet50':          ('resnet50',  None),
    'PRE_ResNet50':      ('resnet50',  'ResNet50_Weights.IMAGENET1K_V2'),
    'ResNet101':         ('resnet101', None),
    'PRE_ResNet101':     ('resnet101', 'ResNet101_Weights.IMAGENET1K_V2'),
}
# Models that run with a batch size other than the default above.
BATCH_SIZE_OVERRIDES = {'PRE_ResNet18-test': 43}

MODELNAME = 'PRE_VGG16'

if MODELNAME not in MODEL_SPECS:
    raise ValueError(f"Unknown MODELNAME {MODELNAME!r}; expected one of {sorted(MODEL_SPECS)}")
hub_entry, pretrained_weights = MODEL_SPECS[MODELNAME]
BATCH_SIZE = BATCH_SIZE_OVERRIDES.get(MODELNAME, BATCH_SIZE)

model = torch.hub.load('pytorch/vision', hub_entry, weights=pretrained_weights)
# Swap the final layer for a NUM_CLASSES-way head.
if hub_entry.startswith('vgg'):
    model.classifier[6] = torch.nn.Linear(4096, NUM_CLASSES)
    new_head = model.classifier[6]
else:
    # fc.in_features is 512 for resnet18/34 and 2048 for resnet50/101.
    model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
    new_head = model.fc
if pretrained_weights is None:
    # Training from scratch: re-initialise every supported layer.
    model.apply(init_weights)
else:
    # Pretrained backbone: only the freshly created head is re-initialised.
    init_weights(new_head)
model = model.to(device=DEVICE,dtype=DTYPE)
model.requires_grad_(True)
# From-scratch runs get a patience of 100 epochs, fine-tuning runs 10.
antilatestopper = EarlyStopper('val_loss',lower_is_better=True,patience=100 if pretrained_weights is None else 10)

# Optimizer selection: OPTIMIZERNAME -> (optimizer class, keyword arguments).
# Naming scheme: <optimizer>_<lr>_<momentum>[_<weight decay>]; the "U" recipe
# additionally attaches a ReduceLROnPlateau scheduler.
OPTIMIZER_SPECS = {
    'SGD_1e-4_0,9':       (torch.optim.SGD,  {'lr': 1e-4, 'momentum': 0.9, 'weight_decay': 0}),     # VGG16 -> Good
    'SGD_3e-4_0,9':       (torch.optim.SGD,  {'lr': 3e-4, 'momentum': 0.9, 'weight_decay': 0}),
    'SGD_3e-5_0,9':       (torch.optim.SGD,  {'lr': 3e-5, 'momentum': 0.9, 'weight_decay': 0}),
    'SGD_1e-3_0,9_1e-4':  (torch.optim.SGD,  {'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-4}),  # ResNet -> Good!
    'SGD_1e-4_0,9_1e-5':  (torch.optim.SGD,  {'lr': 1e-4, 'momentum': 0.9, 'weight_decay': 1e-5}),
    'SGD_3e-5_0,9_1e-5':  (torch.optim.SGD,  {'lr': 3e-5, 'momentum': 0.9, 'weight_decay': 1e-5}),
    'SGD_1e-3U_0,9_1e-4': (torch.optim.SGD,  {'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-4}),
    'SGD_1e-3_0,9_1e-5':  (torch.optim.SGD,  {'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-5}),
    'Adam_5e-5':          (torch.optim.Adam, {'lr': 5e-5, 'betas': (0.9, 0.999), 'weight_decay': 0}),
    'Adam_1e-5':          (torch.optim.Adam, {'lr': 1e-5, 'betas': (0.9, 0.999), 'weight_decay': 0}),
}

OPTIMIZERNAME = 'SGD_1e-4_0,9'

if OPTIMIZERNAME not in OPTIMIZER_SPECS:
    raise ValueError(f"Unknown OPTIMIZERNAME {OPTIMIZERNAME!r}; expected one of {sorted(OPTIMIZER_SPECS)}")
optimizer_class, optimizer_kwargs = OPTIMIZER_SPECS[OPTIMIZERNAME]
optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)

if OPTIMIZERNAME == 'SGD_1e-3U_0,9_1e-4':
    # factor is approximately 1/sqrt(10): two reductions lower the LR tenfold.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.316227766, patience=5, threshold=0, cooldown=0, min_lr=1e-4)
    # NOTE(review): presumably starts the scheduler in cooldown so the first
    # epochs cannot trigger a reduction -- confirm against the scheduler source.
    scheduler.cooldown_counter = 5

In [None]:
train_ds = ImageGenerator(train_img_dir,size=(HEIGHT,WIDTH))
val_ds = ImageGenerator(val_img_dir,size=(HEIGHT,WIDTH))
test_ds = ImageGenerator(test_img_dir,size=(HEIGHT,WIDTH))

PREPREPROCESSOR = 'CutMixUp' #MixUp > CutMix > None

# Batch-level label-mixing transforms offered by each PREPREPROCESSOR choice.
MIX_TRANSFORMS = {
    'None': (),
    'MixUp': (torchvision.transforms.v2.MixUp,),
    'CutMix': (torchvision.transforms.v2.CutMix,),
    'CutMixUp': (torchvision.transforms.v2.MixUp, torchvision.transforms.v2.CutMix),
}
if PREPREPROCESSOR not in MIX_TRANSFORMS:
    raise ValueError(f"Unknown PREPREPROCESSOR {PREPREPROCESSOR!r}; expected one of {sorted(MIX_TRANSFORMS)}")

mixers = [mix(num_classes=7) for mix in MIX_TRANSFORMS[PREPREPROCESSOR]]
if mixers:
    # Per batch, pick one of: plain one-hot labels or one of the mixing transforms.
    label_stage = torchvision.transforms.v2.RandomChoice([MakeOHE(), *mixers])
else:
    label_stage = MakeOHE()
# Either way the training labels leave the pipeline as per-class score vectors.
pretransform = torchvision.transforms.v2.Compose([
    torchvision.transforms.v2.ToDtype(torch.float32, True),
    label_stage
])

def collate_fn(batch):
    """Collate samples with the default collate, then apply `pretransform` to (images, labels)."""
    return pretransform(*torch.utils.data.default_collate(batch))

# Only the training loader shuffles and mixes labels; validation and test keep integer labels.
train_dataloader = torch.utils.data.DataLoader(train_ds,batch_size=BATCH_SIZE,shuffle=True,pin_memory=True,collate_fn=collate_fn)
val_dataloader = torch.utils.data.DataLoader(val_ds,batch_size=BATCH_SIZE,pin_memory=True)
test_dataloader = torch.utils.data.DataLoader(test_ds,batch_size=BATCH_SIZE)

print(model)
# model.parameters() is a generator, so report the parameter count instead of its repr.
print(f"Parameters: {sum(p.numel() for p in model.parameters())}")

In [None]:
# Training-time augmentation, applied by Manager to training inputs only.
# RandomAffine(0) serves as the "leave unchanged" option inside each RandomChoice.
blur_choice = v2.RandomChoice([
    v2.RandomAffine(0),
    v2.GaussianBlur(5),
    v2.GaussianBlur(9),
])
warp_choice = v2.RandomChoice([
    v2.RandomAffine(0),
    v2.ElasticTransform(alpha=768, sigma=14),
    v2.ElasticTransform(alpha=2048, sigma=20)
])
transforms = v2.Compose([
    AddGaussianNoise(0,0.0125),
    blur_choice,
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.ColorJitter(0.25,0.0,0.25),
    warp_choice,
    v2.RandomRotation((-15,15))
]).to(DEVICE)

# Per-class sample counts; the active line covers the train and valid splits.
#class_counts = torch.tensor([23, 20, 15, 24, 90, 21, 20]) #Train only
class_counts = torch.tensor([28, 25, 19, 29, 109, 26, 24]) #Train and Valid
#class_counts = torch.tensor([33, 29, 22, 34, 128, 30, 28]) #Train, Valid, and Test
print(class_counts)
# Inverse-frequency class weights: mean count divided by each class count.
weight = np.mean(class_counts.numpy())/class_counts
print(weight)
# CrossEntropyLoss takes raw logits, so the model needs no softmax layer.
loss_fn = torch.nn.CrossEntropyLoss(weight=weight.to(device=DEVICE,dtype=DTYPE))
LOSSNAME = 'CE_wprop'

print(scheduler)

model_manager = Manager(train_dataloader, val_dataloader, model, loss_fn, optimizer, transforms, scheduler)

In [None]:
# Reset first: if training raises or is interrupted, model_history is None
# rather than a stale history left over from an earlier run of this cell.
model_history = None
# Train for at most EPOCHS epochs; antilatestopper decides when to stop early
# and RBW controls whether the best weights are loaded back afterwards.
model_history = model_manager.train(EPOCHS, antilatestopper, restore_best_weights = RBW)


# Results

In [None]:
# Rebuild the train / validation loaders as plain loaders (no shuffling, no
# label-mixing collate_fn) so Manager.evaluate receives integer labels.
del train_dataloader, val_dataloader
train_dataloader = torch.utils.data.DataLoader(train_ds,batch_size=BATCH_SIZE)
val_dataloader = torch.utils.data.DataLoader(val_ds,batch_size=BATCH_SIZE)

# Score all three splits, in train / valid / test order.
train_metric, val_metric, test_metric = (
    model_manager.evaluate(loader)
    for loader in (train_dataloader, val_dataloader, test_dataloader)
)

print(f"Training Error: Accuracy: {(100*train_metric['acc']):>.3f}%, Avg loss: {train_metric['loss']:>8f}")
print(f"Validating Error: Accuracy: {(100*val_metric['acc']):>.3f}%, Avg loss: {val_metric['loss']:>8f}")
print(f"Testing Error: Accuracy: {(100*test_metric['acc']):>.3f}%, Avg loss: {test_metric['loss']:>8f}")


In [None]:
# Train / valid / test accuracies as percentages, with ',' as the decimal mark
# so the run name (used as a file name) gains no extra dots.
best_metrics = '_'.join(
    '{:.2%}'.format(metric['acc']) for metric in (train_metric, val_metric, test_metric)
).replace('.',',')
# Input resolution tag: "448^2" for square inputs, otherwise "WIDTH×HEIGHT".
size = str(WIDTH)+'^2' if WIDTH == HEIGHT else str(WIDTH)+'×'+str(HEIGHT)
# Suffix marking reduced-precision weights; empty for any other dtype.
PRECISION_SUFFIXES = {torch.float16: "-half", torch.bfloat16: "-bhalf"}
precision_name = PRECISION_SUFFIXES.get(next(model.parameters()).dtype, '')
# Full run identifier; the -E field is the zero-indexed best epoch.
NAME = MODELNAME+precision_name+'-'+size+'-'+DATASETNAME+"+"+PREPREPROCESSOR+'-'+OPTIMIZERNAME+'-'+LOSSNAME+'-BS'+str(BATCH_SIZE) \
+'-E'+str(model_history['best_epoch'])+'-'+best_metrics
print(NAME)

In [None]:
def count_confusions(dataloader, num_classes=7):
    """Build a confusion matrix (rows: actual class, columns: predicted class).

    `dataloader` must yield one sample per batch; each image is classified
    through `model_manager.predict`.
    """
    counts = np.zeros((num_classes, num_classes), dtype=int)
    for image, label in dataloader:
        predicted = model_manager.predict(image.squeeze(0)).cpu().argmax(1).item()
        # .item() gives plain ints, so numpy does integer (not fancy) indexing.
        counts[label.item(), predicted] += 1
    return counts

def draw_confusions(counts, title, position):
    """Draw one annotated confusion matrix into panel `position` of a 1x3 grid."""
    plt.subplot(1, 3, position)
    plt.imshow(counts)
    for row in range(counts.shape[0]):
        for col in range(counts.shape[1]):
            plt.text(col, row, counts[row, col], ha="center", va="center", color="w")
    plt.title(title)
    plt.xlabel('Prediction')
    plt.ylabel('Actual')

# Single-sample loaders: predictions are made one image at a time.
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=1)
val_dl = torch.utils.data.DataLoader(val_ds,batch_size=1)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=1)
train_count = count_confusions(train_dl)
val_count = count_confusions(val_dl)
test_count = count_confusions(test_dl)

plt.figure(figsize=(15, 5))
draw_confusions(train_count, 'Train', 1)
draw_confusions(val_count, 'Valid', 2)
draw_confusions(test_count, 'Test', 3)
plt.savefig(NAME+'-Result') #save before show
plt.show()

In [None]:
# Pull the recorded curves out of the history returned by Manager.train.
train_loss, val_loss, train_acc, val_acc, best_epoch = (
    model_history[key]
    for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc', 'best_epoch')
)

plt.figure(figsize=(15, 5))
# One panel per metric: (position, training curve, validation curve, title, y label, y limits).
panels = (
    (1, train_loss, val_loss, 'Loss', 'Loss Value', [0, 2.5]),
    (2, train_acc, val_acc, 'Accuracy', 'acc', [0, 1]),
)
for position, train_curve, val_curve, panel_title, y_label, y_limits in panels:
    plt.subplot(1, 2, position)
    plt.plot(range(len(train_curve)), train_curve, '-r', label='Training')
    plt.plot(range(len(val_curve)), val_curve, '-b', label='Validation')
    # Vertical line marks the (zero-indexed) best epoch.
    plt.axvline(x=best_epoch)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.ylim(y_limits)
    plt.legend()
plt.savefig(NAME+'-TrainPlot')
plt.show()

In [None]:
# Persist the whole model object (not only its state_dict) under the run name.
# NOTE(review): saving a module this way pickles it, so loading the file
# presumably needs the same class definitions importable -- confirm before sharing.
torch.save(model, NAME+'.pth')

In [None]:
# Append one record for this run to the shared log: run name, loss, optimizer
# and augmentation reprs, then the four history curves as tab-separated rows.
with open("Log-CLS.txt", "a") as log_file:
    log_file.write('Name: ')
    log_file.write(NAME)
    log_file.write('\n')
    for component in (loss_fn, optimizer, transforms):
        log_file.write(str(component))
        log_file.write('\n')
    log_file.write('History: (Train Loss, Val Loss, Train Accuracy, Val Accuracy)\n')
    for curve in (train_loss, val_loss, train_acc, val_acc):
        # Every value is followed by a tab, including the last one of the row.
        log_file.write(''.join(str(value)+'\t' for value in curve))
        log_file.write('\n')
    log_file.write('-------------------------------\n')