In [None]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import random
from torchvision import models, datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import sklearn.metrics as metric
import shutil, glob
from tqdm import tqdm
from PIL import Image
import pickle

In [None]:
import cv2
import numpy as np


def img_range(x):
    """Rescale an array to the 0-255 range and return it as ``uint8``.

    The input is shifted up by ``abs(x.min())``, divided by the maximum of
    the shifted array and multiplied by 255.  When the minimum is negative
    this maps it to 0; when the minimum is positive the shift is *added*,
    so the smallest output value is then above 0.

    Args:
        x: Non-empty NumPy array of real values.

    Returns:
        ``uint8`` array with the same shape as ``x``.  A constant
        non-positive input (for example an all-zero image) yields all zeros.
    """
    shifted = x + abs(x.min())
    peak = shifted.max()
    if peak == 0:
        # Every shifted value is 0, so the division below would be 0/0 (NaN)
        # and the uint8 cast of NaN is undefined.  Return a black image.
        return np.zeros(x.shape, dtype=np.uint8)
    scaled = shifted / peak * 255
    return scaled.astype(np.uint8)


def DFT(img, offset=11, mode='HP'):
    """Filter a 2-D grayscale image in the frequency domain.

    The image is transformed with ``cv2.dft``, the spectrum is centred,
    multiplied by a binary mask, and transformed back with ``cv2.idft``
    (no scaling flag is passed).  The magnitude of the complex result is
    returned.

    Args:
        img: 2-D array (rows x cols).
        offset: Size parameter of the mask, in frequency bins.
        mode: ``'LP'`` keeps only a ``2*offset`` square around the spectrum
            centre.  ``'HP'`` removes a vertical and a horizontal band of
            width ``2*(offset//2)`` through the centre, plus a central block
            reaching ``3*offset//2`` bins from the centre.

    Returns:
        2-D float array with the magnitude of the inverse transform.

    Raises:
        ValueError: If ``mode`` is neither ``'LP'`` nor ``'HP'``.
    """
    if mode not in ('LP', 'HP'):
        # Previously an unknown mode fell through and failed later with an
        # UnboundLocalError on ``mask``.
        raise ValueError(f"mode must be 'LP' or 'HP', got {mode!r}")

    dft = cv2.dft(np.float32(img), flags=cv2.DFT_COMPLEX_OUTPUT)
    # Shift only the two spatial axes; the last axis holds (real, imag).
    dft_shift = np.fft.fftshift(dft, axes=(0, 1))
    rows, cols = img.shape
    crow, ccol = rows // 2, cols // 2

    if mode == 'LP':
        mask = np.zeros((rows, cols, 2), np.uint8)
        # Clamp the start so a large offset cannot wrap to a negative index.
        mask[max(crow - offset, 0):crow + offset,
             max(ccol - offset, 0):ccol + offset] = 1
    else:
        keep = np.ones((rows, cols), np.uint8)
        half = offset // 2
        keep[:, max(ccol - half, 0):ccol + half] = 0
        keep[max(crow - half, 0):crow + half, :] = 0

        # Central block, both corners inclusive.  Rows and columns are
        # indexed explicitly: cv2.rectangle expects (column, row) points,
        # so feeding it row-derived x values transposed the block on
        # non-square images.
        reach = 3 * offset // 2
        row_lo, col_lo = crow - reach, ccol - reach
        row_hi, col_hi = rows - row_lo, cols - col_lo
        keep[max(row_lo, 0):row_hi + 1, max(col_lo, 0):col_hi + 1] = 0

        # Same mask for the real and the imaginary channel.
        mask = np.dstack([keep, keep])

    # apply mask and inverse DFT
    dft_shift_masked = dft_shift * mask
    inv_masked = np.fft.ifftshift(dft_shift_masked, axes=(0, 1))
    imginv_masked = cv2.idft(inv_masked)
    img_dft = cv2.magnitude(imginv_masked[:, :, 0], imginv_masked[:, :, 1])
    return img_dft


In [None]:
class elpv(torch.utils.data.Dataset):
    """Grayscale cell images paired with a high-pass filtered copy.

    Each item is ``(image, filter_image, label, relative_path)``: two
    normalized 1x288x288 tensors, a one-hot label tensor of length 2 and the
    sample's ``"Path"`` entry from the index.
    """

    def __init__(self, path, mode='train', types=['mono', 'poly']):
        """Load the pickled sample index and build the transforms.

        Args:
            path: Directory that holds the index pickle; sample paths are
                joined onto it.
            mode: ``'train'`` reads ``elpv_infos_{types}_train.pkl`` and
                enables augmentation; any other value reads
                ``elpv_infos_{types}_test.pkl`` and only resizes.
            types: Tag interpolated into the index file name and compared
                with every sample's ``"Type"`` entry.
                NOTE(review): callers in this notebook pass a single string
                such as ``'mono'``; the list default does not look usable
                as-is - confirm.
        """
        self.mode = mode
        self.type = types
        self.path = path
        split = 'train' if mode == 'train' else 'test'
        self.infos_path = os.path.join(path, f"elpv_infos_{types}_{split}.pkl")

        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at index files from a trusted source.
        with open(self.infos_path, 'rb') as f:
            self.samples = pickle.load(f)

        self.to_tensor = transforms.ToTensor()

        # Geometric augmentation, applied jointly to image and filtered copy.
        self.transform_t_geometric = transforms.Compose([
            transforms.Resize((288,288)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.2),
        ])

        # Photometric augmentation, applied to the plain image only.
        self.transform_t_colour = transforms.Compose([
            transforms.ColorJitter(brightness=(0.8, 1.2), contrast=(0.7, 1.3))
        ])

        self.resized_v = transforms.Resize((288,288))
        self.normalized_tv_img = transforms.Normalize(mean=[0.5968], std=[0.0977])
        self.normalized_tv_filter = transforms.Normalize(mean=[0.1], std=[0.0984])

    def __getitem__(self, index):
        """Return ``(image, filter_image, one_hot_label, relative_path)``.

        Raises:
            ValueError: If the sample's ``"Type"`` differs from ``self.type``.
            FileNotFoundError: If the image file cannot be read.
        """
        sample = self.samples[index]
        img_path = os.path.join(self.path, sample["Path"])
        lab = int(sample["Class"])

        if sample["Type"] != self.type:
            # Raising a bare string is itself a TypeError in Python 3, which
            # hid this message; raise a real exception instead.
            raise ValueError(f"TypeMatchError: Having issue in {sample}")

        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        filter_image = img_range(DFT(image, offset=5, mode='HP'))

        image = self.to_tensor(image)
        filter_image = self.to_tensor(filter_image)

        if self.mode == "train":
            # Stack both views so they receive the same random flips.
            img_cat_filter = torch.cat([image, filter_image], dim=0)
            img_cat_filter_geo = self.transform_t_geometric(img_cat_filter)
            filter_image = img_cat_filter_geo[1].unsqueeze(0)
            image_colour_aug = self.transform_t_colour(img_cat_filter_geo[0].unsqueeze(0))
            image = self.normalized_tv_img(image_colour_aug)
            filter_image = self.normalized_tv_filter(filter_image)
        else:
            image = self.resized_v(image)
            filter_image = self.resized_v(filter_image)
            image = self.normalized_tv_img(image)
            filter_image = self.normalized_tv_filter(filter_image)

        label = torch.zeros(2)
        label[lab] = 1

        return image, filter_image, label, sample["Path"]

    def __len__(self):
        """Number of samples listed in the loaded index."""
        return len(self.samples)

In [None]:
# Cell type for this run; later cells reuse `mod` to name saved weights and figures.
mod = 'mono'
# mode='train' reads the *_train.pkl index; any other mode (here 'val') reads the *_test.pkl index.
train_data = elpv(f"/kaggle/input/elpv-and-pvel-ad-dataset/elpv_{mod}", mode='train',types=mod)
val_data = elpv(f"/kaggle/input/elpv-and-pvel-ad-dataset/elpv_{mod}", mode='val', types=mod)

In [None]:
# Batches of 64 loaded with num_workers=4; both loaders are created with shuffle=True.
train_loader=torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True , num_workers = 4)
val_loader=torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=True , num_workers = 4)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.utils.model_zoo as model_zoo

# Download URLs of ResNet checkpoints, keyed by architecture name.
# ResNet._load_resnet_pretrained fetches the selected one with model_zoo.load_url.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The last convolution widens the feature map to ``planes * 4`` channels.
    ``stride`` and ``dilation`` are applied on the 3x3 convolution.  When
    ``downsample`` is given it is applied to the input before the residual
    addition; otherwise the input itself is added.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super().__init__()
        widened = planes * 4

        # 1x1 reduce
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 spatial convolution; padding equals dilation so that size is
        # preserved when stride == 1
        self.conv2 = nn.Conv2d(
            planes, planes,
            kernel_size=3, stride=stride,
            padding=dilation, dilation=dilation,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 expand
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) with ``expansion == 1``.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut branch; the raw input is used when it is ``None``.
        groups: Must be 1.
        base_width: Must be 64.
        dilation: Dilation (and padding) of both convolutions.
        norm_layer: Normalization layer factory; defaults to
            ``nn.BatchNorm2d``.
        activation: Activation module.  ``None`` (the default) creates a new
            ``nn.ReLU(inplace=True)`` for this block.
        residual_only: When true, ``forward`` returns the main branch
            without adding the shortcut or applying the final activation.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
    """

    expansion = 1

    def __init__(self, inplanes, planes,
                 stride=1, downsample=None, groups=1, base_width=64,
                 dilation=1, norm_layer=None,
                 activation=None, residual_only=False):

        super(BasicBlock, self).__init__()
        self.residual_only = residual_only
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation is None:
            # A module used as a default argument is built once at definition
            # time and then shared by every block (hooks fire for all blocks,
            # named_modules() lists it only once).  Build one per block.
            activation = nn.ReLU(inplace=True)
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and '
                             'base_width=64')
        # Both self.conv1 and self.downsample layers downsample the input
        # when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.act = activation
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block; see ``residual_only`` for the short-circuit case."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if self.residual_only:
            return out
        out = out + identity
        out = self.act(out)

        return out


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals its dilation."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False,
                     dilation=dilation)


class DCMAF(nn.Module):
    """Channel-wise gating weights for fusing two feature maps.

    Both inputs are globally average-pooled to ``[B, C, 1, 1]`` and each is
    passed through its own 1x1 bottleneck (``num_channel -> out_channel ->
    num_channel``).  The sigmoid of the difference of the two descriptors is
    the weight for the first input; one minus it is the weight for the second.
    """

    def __init__(self, num_channel, out_channel):
        super().__init__()

        # TODO (from the original author): add convolution here
        self.pool = nn.AdaptiveAvgPool2d(1)  # [B, C, 1, 1]

        # squeeze layers (image branch, filter branch)
        self.conv1 = nn.Conv2d(num_channel, out_channel, kernel_size=1)
        self.conv2 = nn.Conv2d(num_channel, out_channel, kernel_size=1)
        # excite layers (image branch, filter branch)
        self.conv3 = nn.Conv2d(out_channel, num_channel, kernel_size=1)
        self.conv4 = nn.Conv2d(out_channel, num_channel, kernel_size=1)
        self.activation = nn.Sigmoid()

    def forward(self, image, filter_img):
        """Return ``(w_image, w_filter)`` with ``w_filter = 1 - w_image``."""
        image_desc = self.conv3(F.relu(self.conv1(self.pool(image))))
        filter_desc = self.conv4(F.relu(self.conv2(self.pool(filter_img))))

        gate = self.activation(image_desc - filter_desc)
        return gate, 1 - gate


class ResNet(nn.Module):
    """Two-stream ResNet that fuses an image stream with a filtered-image stream.

    Each stream has its own single-channel stem and four residual stages
    (attributes suffixed ``_d`` belong to the filtered-image stream).  After
    every stage a ``DCMAF`` module produces channel weights; for stages 1-3
    each stream receives the other stream's weighted features, and after
    stage 4 the two weighted maps are summed, pooled and classified.

    Args:
        block: Residual block class; it must expose ``expansion``.
        layers: Number of blocks in each of the four stages.
        num_classes: Size of the classifier output.
        pretrained_url: If a string and ``mode == 'train'``, checkpoint URL
            whose stem/stage weights are copied into both streams.
        mode: See ``pretrained_url``.

    Note:
        ``self.maxpool``, ``self.relu`` and ``self.relu_d`` are registered
        but ``forward`` does not call them.
    """

    def __init__(self, block, layers, num_classes=2, pretrained_url=None, mode='train'):

        super(ResNet, self).__init__()
        # Channel multiplier of the chosen block.  Fusion and classifier
        # widths are scaled by it so that blocks with expansion != 1 produce
        # matching shapes; with expansion == 1 the widths are unchanged.
        widen = block.expansion

        # original image
        self.inplanes = 64
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # filtered image
        self.inplanes = 64
        self.conv1_d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1_d = nn.BatchNorm2d(64)
        self.relu_d = nn.ReLU(inplace=True)
        self.layer1_d = self._make_layer(block, 64, layers[0])
        self.layer2_d = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3_d = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4_d = self._make_layer(block, 512, layers[3], stride=2)

        ##### dcmaf fusion (one per stage; bottleneck is half the stage width)
        self.fusion_2 = DCMAF(64 * widen, 32 * widen)
        self.fusion_3 = DCMAF(128 * widen, 64 * widen)
        self.fusion_4 = DCMAF(256 * widen, 128 * widen)
        self.fusion_5 = DCMAF(512 * widen, 256 * widen)
        ######

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Sequential(
            nn.Linear(512 * widen, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Linear(512, 256),
            nn.ReLU(True),
            nn.Linear(256, num_classes)
        )

        if isinstance(pretrained_url, str) and mode == 'train':
            self.pretrained_url = pretrained_url
            self._load_resnet_pretrained()
            print("********************** Pretrained model loaded **********************")

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        A 1x1 conv + batch-norm shortcut is created for the first block when
        the stride or the channel count changes.  ``dilation`` is forwarded
        to every block except the first.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def _load_resnet_pretrained(self):
        """Copy checkpoint weights from ``self.pretrained_url`` into both streams.

        Only checkpoint keys that also exist in this model are used:
        ``conv1*`` weights are averaged over the input-channel axis to fit
        the single-channel stems, ``bn1*`` and ``layer*`` entries are copied
        unchanged.  Every entry is duplicated under the matching ``_d`` key.
        """
        # pretrain_dict = torch.load(self.pretrained_path)
        pretrain_dict = model_zoo.load_url(self.pretrained_url)
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                if k.startswith('conv1'):
                    # Collapse the input channels into one by averaging.
                    model_dict[k] = torch.mean(v, 1).data. \
                        view_as(state_dict[k])
                    model_dict[k.replace('conv1', 'conv1_d')] = torch.mean(v, 1).data. \
                        view_as(state_dict[k.replace('conv1', 'conv1_d')])

                elif k.startswith('bn1'):
                    model_dict[k] = v
                    model_dict[k.replace('bn1', 'bn1_d')] = v
                elif k.startswith('layer'):
                    model_dict[k] = v
                    # 'layerN.rest' -> 'layerN_d.rest'
                    model_dict[k[:6] + '_d' + k[6:]] = v

        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    @staticmethod
    def _weighted_pair(fusion, out, out_f):
        """Return both feature maps scaled by the weights ``fusion`` yields for them."""
        w_i, w_f = fusion(out, out_f)
        return out.mul(w_i), out_f.mul(w_f)

    def forward(self, x, f):
        """Classify from an image batch ``x`` and its filtered counterpart ``f``.

        Returns:
            Tensor of shape ``[batch, num_classes]`` with unnormalized scores.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out_f = F.relu(self.bn1_d(self.conv1_d(f)))

        # Stages 1-3: each stream is augmented with the other stream's
        # weighted features.
        exchange_stages = (
            (self.layer1, self.layer1_d, self.fusion_2),
            (self.layer2, self.layer2_d, self.fusion_3),
            (self.layer3, self.layer3_d, self.fusion_4),
        )
        for layer, layer_d, fusion in exchange_stages:
            out = layer(out)
            out_f = layer_d(out_f)
            img_w, fil_w = self._weighted_pair(fusion, out, out_f)
            out = out + fil_w
            out_f = out_f + img_w

        # Stage 4: the two weighted maps are merged into a single map.
        out = self.layer4(out)
        out_f = self.layer4_d(out_f)
        img_w, fil_w = self._weighted_pair(self.fusion_5, out, out_f)
        out = img_w + fil_w

        out = self.pool(out)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out



In [None]:
from torch.optim import SGD

# Seed torch's CPU and CUDA generators before the model is built.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

num_epoch = 60

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Stage depths [3, 4, 6, 3] with BasicBlock; the constructor loads the checkpoint at model_urls['resnet34'].
model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3], pretrained_url = model_urls['resnet34'])

# Wrap for multi-GPU execution when more than one CUDA device is visible.
if torch.cuda.device_count()>1:
    print("let use",torch.cuda.device_count(),"gpu")
    model=nn.DataParallel(model)

model=model.to(device)

# build optimizer and scheduler
optimizer = SGD(model.parameters(), lr=0.006, momentum=0.9, weight_decay=0.0005)
# Polynomial decay: the lr multiplier is 1 at epoch 0 and reaches 0 at epoch == num_epoch.
lambda1 = lambda epoch: ((1 - (epoch / num_epoch)) ** 0.9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda1)

# Per-class loss weights: training-set size divided by a hard-coded count per class.
# NOTE(review): 529 and 277 are presumably the class-0 / class-1 sample counts of the
# mono training split - confirm against the index file.
class_weights = [len(train_data)/529, len(train_data)/277]
class_weights = torch.Tensor(class_weights).to(device)
# Loss functions
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
import time
import torch
import numpy as np
import sklearn.metrics as metric
from tqdm import tqdm



# Per-epoch mean losses (read by the loss-plot cell) and the best validation
# F1 scores seen so far.
epoch_losses_train = []
epoch_losses_val = []
best_f1_weighted, best_f1_binary = 0, 0

start = time.time()
for epoch in range(1, num_epoch+1):
    print("epoch: %d/%d" % (epoch, num_epoch))
    ############################################################################
    # train:
    ############################################################################
    model.train()
    batch_losses = []
    for imgs, filter_img, label, _ in train_loader:
        imgs = imgs.to(device)
        filter_img = filter_img.to(device)
        label = (label.type(torch.float32)).to(device)

        optimizer.zero_grad()  # (reset gradients)

        out = model(imgs, filter_img)
        loss = criterion(out, label)
        # .item() returns the scalar as a Python float (replaces the
        # deprecated ``loss.data`` access).
        batch_losses.append(loss.item())

        # optimization step:
        loss.backward()
        optimizer.step()  # (perform optimization step)

    epoch_loss = np.mean(batch_losses)
    epoch_losses_train.append(epoch_loss)
    print("train loss: %g" % epoch_loss)

    scheduler.step()

    ############################################################################
    # val:
    ############################################################################
    t = []  # target class indices over the whole validation set
    p = []  # predicted class indices over the whole validation set
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for imgs, filter_img, label, _ in val_loader:
            imgs = imgs.to(device)
            filter_img = filter_img.to(device)
            label = (label.type(torch.float32)).to(device)

            out = model(imgs, filter_img)

            # Labels are one-hot, so the index of the maximum is the class.
            preds = out.max(dim=1)[1].cpu().numpy()
            targets = label.max(dim=1)[1].cpu().numpy()

            t.extend(targets)
            p.extend(preds)

            # compute the loss:
            batch_losses.append(criterion(out, label).item())

    epoch_loss = np.mean(batch_losses)
    epoch_losses_val.append(epoch_loss)
    print("val loss: %g" % epoch_loss)

    # zero_division matches the classification_report calls below and avoids
    # a warning when one class is never predicted.
    f1_weighted = metric.f1_score(t, p, average='weighted', zero_division=0.0)
    print("f1 score: ", f1_weighted)
    f1_binary = metric.f1_score(t, p, average='binary', zero_division=0.0)
    print("f1 binary: ", f1_binary)

    if f1_weighted > best_f1_weighted:
        print("############ Best Result f1_weighted ############")
        print(metric.classification_report(t, p, zero_division=0.0))
        out_filename = os.path.join('/kaggle/working/', f'best_f1_weighted_{mod}.pth')
        torch.save(model.state_dict(), out_filename)
        best_f1_weighted = f1_weighted

    if f1_binary > best_f1_binary:
        print("############ Best Result f1_binary ############")
        print(metric.classification_report(t, p, zero_division=0.0))
        out_filename = os.path.join('/kaggle/working/', f'best_f1_binary_{mod}.pth')
        torch.save(model.state_dict(), out_filename)
        best_f1_binary = f1_binary

    # Wall-clock budget: once more than 18000 s have elapsed since the loop
    # started, save the current weights and stop.
    end = time.time()
    forward_time = end - start
    if forward_time > 18000:
        print(f"############ last=> epoch {epoch} time :{forward_time} ############")
        print(metric.classification_report(t, p, zero_division=0.0))
        out_filename = os.path.join('/kaggle/working/',  f'last_weights_{mod}.pth')
        torch.save(model.state_dict(), out_filename)
        break


In [None]:
# Loss curves for the run. Epochs are numbered from 1 to match the training
# log, and each series gets its own x values so that lists of different
# length (e.g. after an interrupted run) still plot.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(epoch_losses_train) + 1), epoch_losses_train, label="train")
ax.plot(np.arange(1, len(epoch_losses_val) + 1), epoch_losses_val, label="val")
ax.set_xlabel("epoch")
ax.set_ylabel("mean loss")
ax.set_title(f"Training and validation loss ({mod})")
fig.legend()
plt.savefig(f'/kaggle/working/losses_plot_{mod}.png')

In [None]:
# Reload the checkpoint with the best weighted F1 and show predictions for up
# to 14 validation samples in a 2 x 7 grid.
model.load_state_dict(torch.load(f'/kaggle/working/best_f1_weighted_{mod}.pth', map_location=device))
model.eval()

# One batch is enough; materialising list(val_loader) would load the whole
# validation set just to pick a single batch.
samples = next(iter(val_loader))
num_shown = min(14, samples[0].shape[0])

fig, axs = plt.subplots(2, 7, figsize=(30,30),layout="compressed")
axs = axs.flatten()
for panel in axs:
    panel.axis("off")

with torch.no_grad():
    for i in range(num_shown):
        imgs = samples[0][i].unsqueeze(dim=0)
        filter_img = samples[1][i].unsqueeze(dim=0)
        label = samples[2][i]
        # Inputs must be on the same device as the model.
        pred = model(imgs.to(device), filter_img.to(device))
        pred = pred.detach().max(dim=1)[1].cpu().numpy()
        axs[i].imshow(imgs[0].permute(1,2,0).detach().cpu().numpy(), cmap='gray')
        axs[i].set_title(f"label = {label.max(dim=0)[1].numpy()}, pred = {pred.item()}", fontsize=20)

plt.savefig(f'/kaggle/working/PV_output_{mod}.png', bbox_inches='tight', pad_inches=0, dpi=400)

In [None]:
import time
import torch
import numpy as np
import sklearn.metrics as metric
from tqdm import tqdm

from torch.optim import SGD

# Second run: switch the cell type. `mod` also names the saved weights and figures below.
mod = 'poly'
# mode='train' reads the *_train.pkl index; any other mode (here 'val') reads the *_test.pkl index.
train_data = elpv(f"/kaggle/input/elpv-and-pvel-ad-dataset/elpv_{mod}", mode='train',types=mod)
val_data = elpv(f"/kaggle/input/elpv-and-pvel-ad-dataset/elpv_{mod}", mode='val', types=mod)
# Batches of 64 loaded with num_workers=4; both loaders are created with shuffle=True.
train_loader=torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True , num_workers = 4)
val_loader=torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=True , num_workers = 4)

# Seed torch's CPU and CUDA generators before the model is built.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

num_epoch = 60

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Stage depths [3, 4, 6, 3] with BasicBlock; the constructor loads the checkpoint at model_urls['resnet34'].
model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3], pretrained_url = model_urls['resnet34'])

# Wrap for multi-GPU execution when more than one CUDA device is visible.
if torch.cuda.device_count()>1:
    print("let use",torch.cuda.device_count(),"gpu")
    model=nn.DataParallel(model)

model=model.to(device)

# build optimizer and scheduler
optimizer = SGD(model.parameters(), lr=0.006, momentum=0.9, weight_decay=0.0005)
# Polynomial decay: the lr multiplier is 1 at epoch 0 and reaches 0 at epoch == num_epoch.
lambda1 = lambda epoch: ((1 - (epoch / num_epoch)) ** 0.9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda1)

# Per-class loss weights: training-set size divided by a hard-coded count per class.
# NOTE(review): 824 and 339 are presumably the class-0 / class-1 sample counts of the
# poly training split - confirm against the index file.
class_weights = [len(train_data)/824, len(train_data)/339]
class_weights = torch.Tensor(class_weights).to(device)
# Loss functions
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Per-epoch mean losses (read by the loss plot below) and the best validation
# F1 scores seen so far.
epoch_losses_train = []
epoch_losses_val = []
best_f1_weighted, best_f1_binary = 0, 0

start = time.time()
for epoch in range(1, num_epoch+1):
    print("epoch: %d/%d" % (epoch, num_epoch))
    ############################################################################
    # train:
    ############################################################################
    model.train()
    batch_losses = []
    for imgs, filter_img, label, _ in train_loader:
        imgs = imgs.to(device)
        filter_img = filter_img.to(device)
        label = (label.type(torch.float32)).to(device)

        optimizer.zero_grad()  # (reset gradients)

        out = model(imgs, filter_img)
        loss = criterion(out, label)
        # .item() returns the scalar as a Python float (replaces the
        # deprecated ``loss.data`` access).
        batch_losses.append(loss.item())

        # optimization step:
        loss.backward()
        optimizer.step()  # (perform optimization step)

    epoch_loss = np.mean(batch_losses)
    epoch_losses_train.append(epoch_loss)
    print("train loss: %g" % epoch_loss)

    scheduler.step()

    ############################################################################
    # val:
    ############################################################################
    t = []  # target class indices over the whole validation set
    p = []  # predicted class indices over the whole validation set
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for imgs, filter_img, label, _ in val_loader:
            imgs = imgs.to(device)
            filter_img = filter_img.to(device)
            label = (label.type(torch.float32)).to(device)

            out = model(imgs, filter_img)

            # Labels are one-hot, so the index of the maximum is the class.
            preds = out.max(dim=1)[1].cpu().numpy()
            targets = label.max(dim=1)[1].cpu().numpy()

            t.extend(targets)
            p.extend(preds)

            # compute the loss:
            batch_losses.append(criterion(out, label).item())

    epoch_loss = np.mean(batch_losses)
    epoch_losses_val.append(epoch_loss)
    print("val loss: %g" % epoch_loss)

    # zero_division matches the classification_report calls below and avoids
    # a warning when one class is never predicted.
    f1_weighted = metric.f1_score(t, p, average='weighted', zero_division=0.0)
    print("f1 score: ", f1_weighted)
    f1_binary = metric.f1_score(t, p, average='binary', zero_division=0.0)
    print("f1 binary: ", f1_binary)

    if f1_weighted > best_f1_weighted:
        print("############ Best Result f1_weighted ############")
        print(metric.classification_report(t, p, zero_division=0.0))
        out_filename = os.path.join('/kaggle/working/', f'best_f1_weighted_{mod}.pth')
        torch.save(model.state_dict(), out_filename)
        best_f1_weighted = f1_weighted

    if f1_binary > best_f1_binary:
        print("############ Best Result f1_binary ############")
        print(metric.classification_report(t, p, zero_division=0.0))
        out_filename = os.path.join('/kaggle/working/', f'best_f1_binary_{mod}.pth')
        torch.save(model.state_dict(), out_filename)
        best_f1_binary = f1_binary

    # Wall-clock budget: once more than 24000 s have elapsed since the loop
    # started, save the current weights and stop.
    end = time.time()
    forward_time = end - start
    if forward_time > 24000:
        print(f"############ last=> epoch {epoch} time :{forward_time} ############")
        print(metric.classification_report(t, p, zero_division=0.0))
        out_filename = os.path.join('/kaggle/working/',  f'last_weights_{mod}.pth')
        torch.save(model.state_dict(), out_filename)
        break

# Loss curves for the run. Epochs are numbered from 1 to match the training
# log, and each series gets its own x values so that lists of different
# length (e.g. after an interrupted run) still plot.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(epoch_losses_train) + 1), epoch_losses_train, label="train")
ax.plot(np.arange(1, len(epoch_losses_val) + 1), epoch_losses_val, label="val")
ax.set_xlabel("epoch")
ax.set_ylabel("mean loss")
ax.set_title(f"Training and validation loss ({mod})")
fig.legend()
plt.savefig(f'/kaggle/working/losses_plot_{mod}.png')

In [None]:
# Reload the checkpoint with the best weighted F1 and show predictions for up
# to 14 validation samples in a 2 x 7 grid.
model.load_state_dict(torch.load(f'/kaggle/working/best_f1_weighted_{mod}.pth', map_location=device))
model.eval()

# One batch is enough; materialising list(val_loader) would load the whole
# validation set just to pick a single batch.
samples = next(iter(val_loader))
num_shown = min(14, samples[0].shape[0])

fig, axs = plt.subplots(2, 7, figsize=(30,30),layout="compressed")
axs = axs.flatten()
for panel in axs:
    panel.axis("off")

with torch.no_grad():
    for i in range(num_shown):
        imgs = samples[0][i].unsqueeze(dim=0)
        filter_img = samples[1][i].unsqueeze(dim=0)
        label = samples[2][i]
        # Inputs must be on the same device as the model.
        pred = model(imgs.to(device), filter_img.to(device))
        pred = pred.detach().max(dim=1)[1].cpu().numpy()
        axs[i].imshow(imgs[0].permute(1,2,0).detach().cpu().numpy(), cmap='gray')
        axs[i].set_title(f"label = {label.max(dim=0)[1].numpy()}, pred = {pred.item()}", fontsize=20)

plt.savefig(f'/kaggle/working/PV_output_{mod}.png', bbox_inches='tight', pad_inches=0, dpi=400)