In [10]:
!pip install tensorboardX tqdm



In [11]:
# Essential packages for deep learning and data processing
!pip install torch torchvision torchaudio
!pip install tensorboardX
!pip install tqdm
!pip install numpy
!pip install Pillow  # for image processing
!pip install scikit-learn  # for train-test splitting utilities

# Additional utilities that might be needed
!pip install matplotlib  # for visualization if needed
!pip install einops  # often used with Vision Transformers
!pip install timm  # useful for ViT implementations



In [12]:
!pip install ml-collections



In [13]:
%%writefile cswin_unet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops.layers.torch import Rearrange
import torch.utils.checkpoint as checkpoint
import numpy as np


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LePEAttention(nn.Module):
    def __init__(self, dim, resolution, idx, split_size, dim_out=None, num_heads=9, attn_drop=0., proj_drop=0.,
                 qk_scale=None):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out or dim
        self.resolution = resolution
        self.split_size = split_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if idx == -1:
            H_sp, W_sp = self.resolution, self.resolution
        elif idx == 0:
            H_sp, W_sp = self.resolution, self.split_size
        elif idx == 1:
            W_sp, H_sp = self.resolution, self.split_size
        else:
            print("ERROR MODE", idx)
            exit(0)
        self.H_sp = H_sp
        self.W_sp = W_sp
        stride = 1
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

        self.attn_drop = nn.Dropout(attn_drop)

    def im2cswin(self, x):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
        x = img2windows(x, self.H_sp, self.W_sp)
        x = x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
        return x

    def get_lepe(self, x, func):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)

        H_sp, W_sp = self.H_sp, self.W_sp
        x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp)  ### B', C, H', W'

        lepe = func(x)  ### B', C, H', W'
        lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous()

        x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3, 2).contiguous()
        return x, lepe

    def forward(self, qkv):
        """
        x: B L C
        """
        q, k, v = qkv[0], qkv[1], qkv[2]

        ### Img2Window
        H = W = self.resolution
        B, L, C = q.shape

        assert L == H * W, "flatten img_tokens has wrong size"

        q = self.im2cswin(q)
        k = self.im2cswin(k)
        v, lepe = self.get_lepe(v, self.get_v)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # B head N C @ B head C N --> B head N N
        attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
        attn = self.attn_drop(attn)

        x = (attn @ v) + lepe
        x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp, C)  # B head N N @ B head N C

        ### Window2Img
        x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C)  # B H' W' C

        return x


class CSWinBlock(nn.Module):

    def __init__(self, dim, reso, num_heads,
                 split_size, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 last_stage=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.patches_resolution = reso
        self.split_size = split_size
        self.mlp_ratio = mlp_ratio
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm1 = norm_layer(dim)

        if self.patches_resolution == split_size:
            last_stage = True
        if last_stage:
            self.branch_num = 1
        else:
            self.branch_num = 2
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(drop)

        if last_stage:
            self.attns = nn.ModuleList([
                LePEAttention(
                    dim, resolution=self.patches_resolution, idx=-1,
                    split_size=split_size, num_heads=num_heads, dim_out=dim,
                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
                for i in range(self.branch_num)])

        else:
            self.attns = nn.ModuleList([
                LePEAttention(
                    dim // 2, resolution=self.patches_resolution, idx=i,
                    split_size=split_size, num_heads=num_heads // 2, dim_out=dim // 2,
                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
                for i in range(self.branch_num)])

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,
                       drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """

        H = W = self.patches_resolution
        B, L, C = x.shape
        assert L == H * W, "flatten img_tokens has wrong size"
        img = self.norm1(x)
        qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3)

        if self.branch_num == 2:
            x1 = self.attns[0](qkv[:, :, :, :C // 2])
            x2 = self.attns[1](qkv[:, :, :, C // 2:])
            attened_x = torch.cat([x1, x2], dim=2)
        else:
            attened_x = self.attns[0](qkv)
        attened_x = self.proj(attened_x)
        x = x + self.drop_path(attened_x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


def img2windows(img, H_sp, W_sp):
    """
    img: B C H W
    """
    B, C, H, W = img.shape
    img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
    img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C)
    return img_perm


def windows2img(img_splits_hw, H_sp, W_sp, H, W):
    """
    img_splits_hw: B' H W C
    """
    B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))

    img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
    img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return img


class Merge_Block(nn.Module):
    def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim_out, 3, 2, 1)
        self.norm = norm_layer(dim_out)

    def forward(self, x):
        B, new_HW, C = x.shape
        H = W = int(np.sqrt(new_HW))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
        x = self.conv(x)
        B, C = x.shape[:2]
        x = x.view(B, C, -1).transpose(-2, -1).contiguous()
        x = self.norm(x)

        return x

class CARAFE(nn.Module):
    def __init__(self, dim, dim_out, kernel_size=3, up_factor=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.up_factor = up_factor
        self.down = nn.Conv2d(dim, dim // 4, 1)
        self.encoder = nn.Conv2d(dim // 4, self.up_factor ** 2 * self.kernel_size ** 2,
                                 self.kernel_size, 1, self.kernel_size // 2)
        self.out = nn.Conv2d(dim, dim_out, 1)

    def forward(self, x):
        B, new_HW, C = x.shape
        H = W = int(np.sqrt(new_HW))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)


            # N,C,H,W -> N,C,delta*H,delta*W
            # kernel prediction module
        kernel_tensor = self.down(x)  # (N, Cm, H, W)
        kernel_tensor = self.encoder(kernel_tensor)  # (N, S^2 * Kup^2, H, W)
        kernel_tensor = F.pixel_shuffle(kernel_tensor,
                                        self.up_factor)  # (N, S^2 * Kup^2, H, W)->(N, Kup^2, S*H, S*W)
        kernel_tensor = F.softmax(kernel_tensor, dim=1)  # (N, Kup^2, S*H, S*W)
        kernel_tensor = kernel_tensor.unfold(2, self.up_factor, step=self.up_factor)  # (N, Kup^2, H, W*S, S)
        kernel_tensor = kernel_tensor.unfold(3, self.up_factor, step=self.up_factor)  # (N, Kup^2, H, W, S, S)
        kernel_tensor = kernel_tensor.reshape(B, self.kernel_size ** 2, H, W,
                                                  self.up_factor ** 2)  # (N, Kup^2, H, W, S^2)
        kernel_tensor = kernel_tensor.permute(0, 2, 3, 1, 4)  # (N, H, W, Kup^2, S^2)

            # content-aware reassembly module
            # tensor.unfold: dim, size, step
        w = F.pad(x, pad=(self.kernel_size // 2, self.kernel_size // 2,
                                              self.kernel_size // 2, self.kernel_size // 2),
                              mode='constant', value=0)  # (N, C, H+Kup//2+Kup//2, W+Kup//2+Kup//2)
        w = w.unfold(2, self.kernel_size, step=1)  # (N, C, H, W+Kup//2+Kup//2, Kup)
        w = w.unfold(3, self.kernel_size, step=1)  # (N, C, H, W, Kup, Kup)
        w = w.reshape(B, C, H, W, -1)  # (N, C, H, W, Kup^2)
        w = w.permute(0, 2, 3, 1, 4)  # (N, H, W, C, Kup^2)

        x = torch.matmul(w, kernel_tensor)  # (N, H, W, C, S^2)
        x = x.reshape(B, H, W, -1)
        x = x.permute(0, 3, 1, 2)
        x = F.pixel_shuffle(x, self.up_factor)
        x = self.out(x)
        B, C = x.shape[:2]
        x = x.view(B, C, -1).transpose(-2, -1).contiguous()

        return x


class CARAFE4(nn.Module):
    def __init__(self, dim, dim_out, kernel_size=3, up_factor=4):
        super().__init__()
        self.kernel_size = kernel_size
        self.up_factor = up_factor
        self.down = nn.Conv2d(dim, dim // 4, 1)
        self.encoder = nn.Conv2d(dim // 4, self.up_factor ** 2 * self.kernel_size ** 2,
                                 self.kernel_size, 1, self.kernel_size // 2)
        self.out = nn.Conv2d(dim, dim_out, 1)

    def forward(self, x):
        B, new_HW, C = x.shape
        H = W = int(np.sqrt(new_HW))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)


            # N,C,H,W -> N,C,delta*H,delta*W
            # kernel prediction module
        kernel_tensor = self.down(x)  # (N, Cm, H, W)
        kernel_tensor = self.encoder(kernel_tensor)  # (N, S^2 * Kup^2, H, W)
        kernel_tensor = F.pixel_shuffle(kernel_tensor,
                                        self.up_factor)  # (N, S^2 * Kup^2, H, W)->(N, Kup^2, S*H, S*W)
        kernel_tensor = F.softmax(kernel_tensor, dim=1)  # (N, Kup^2, S*H, S*W)
        kernel_tensor = kernel_tensor.unfold(2, self.up_factor, step=self.up_factor)  # (N, Kup^2, H, W*S, S)
        kernel_tensor = kernel_tensor.unfold(3, self.up_factor, step=self.up_factor)  # (N, Kup^2, H, W, S, S)
        kernel_tensor = kernel_tensor.reshape(B, self.kernel_size ** 2, H, W,
                                                  self.up_factor ** 2)  # (N, Kup^2, H, W, S^2)
        kernel_tensor = kernel_tensor.permute(0, 2, 3, 1, 4)  # (N, H, W, Kup^2, S^2)

            # content-aware reassembly module
            # tensor.unfold: dim, size, step
        w = F.pad(x, pad=(self.kernel_size // 2, self.kernel_size // 2,
                                              self.kernel_size // 2, self.kernel_size // 2),
                              mode='constant', value=0)  # (N, C, H+Kup//2+Kup//2, W+Kup//2+Kup//2)
        w = w.unfold(2, self.kernel_size, step=1)  # (N, C, H, W+Kup//2+Kup//2, Kup)
        w = w.unfold(3, self.kernel_size, step=1)  # (N, C, H, W, Kup, Kup)
        w = w.reshape(B, C, H, W, -1)  # (N, C, H, W, Kup^2)
        w = w.permute(0, 2, 3, 1, 4)  # (N, H, W, C, Kup^2)

        x = torch.matmul(w, kernel_tensor)  # (N, H, W, C, S^2)
        x = x.reshape(B, H, W, -1)
        x = x.permute(0, 3, 1, 2)
        x = F.pixel_shuffle(x, self.up_factor)
        x = self.out(x)
        B, C = x.shape[:2]
        x = x.view(B, C, -1).transpose(-2, -1).contiguous()

        return x


class CSWinTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=8, embed_dim=64, depth=[1, 2, 9, 1],
                 split_size=[1, 2, 7, 7],
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0, hybrid_backbone=None, norm_layer=nn.LayerNorm, use_chk=False):
        super().__init__()
        self.use_chk = use_chk
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        heads = num_heads

        #encoder

        self.stage1_conv_embed = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
            Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
            nn.LayerNorm(embed_dim)
        )

        curr_dim = embed_dim

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))]  # stochastic depth decay rule
        print("depth",depth)
        self.stage1 = nn.ModuleList(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[0], reso=img_size // 4, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[0],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth[0])])
        self.merge1 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage2 = nn.ModuleList(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[1], reso=img_size // 8, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:1]) + i], norm_layer=norm_layer)
                for i in range(depth[1])])
        self.merge2 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        temp_stage3 = []
        temp_stage3.extend(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[2], reso=img_size // 16, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:2]) + i], norm_layer=norm_layer)
                for i in range(depth[2])])

        self.stage3 = nn.ModuleList(temp_stage3)
        self.merge3 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage4 = nn.ModuleList(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[3], reso=img_size // 32, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:-1]) + i], norm_layer=norm_layer, last_stage=True)
                for i in range(depth[-1])])

        self.norm = norm_layer(curr_dim)

        # decoder


        self.stage_up4 = nn.ModuleList(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[3], reso=img_size // 32, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:-1]) + i], norm_layer=norm_layer, last_stage=True)
                for i in range(depth[-1])])

        self.upsample4 = CARAFE(curr_dim, curr_dim // 2)
        curr_dim = curr_dim // 2

        self.concat_linear4 = nn.Linear(512, 256)
        self.stage_up3 = nn.ModuleList(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[2], reso=img_size // 16, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:2]) + i], norm_layer=norm_layer)
                for i in range(depth[2])]
        )

        self.upsample3 = CARAFE(curr_dim, curr_dim // 2)
        curr_dim = curr_dim // 2

        self.concat_linear3 = nn.Linear(256, 128)
        self.stage_up2 = nn.ModuleList(
            [CSWinBlock(
                dim=curr_dim, num_heads=heads[1], reso=img_size // 8, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[np.sum(depth[:1]) + i], norm_layer=norm_layer)
                for i in range(depth[1])])
        self.upsample2 = CARAFE(curr_dim, curr_dim // 2)
        curr_dim = curr_dim // 2

        self.concat_linear2 = nn.Linear(128, 64)
        self.stage_up1 = nn.ModuleList([
            CSWinBlock(
                dim=curr_dim, num_heads=heads[0], reso=img_size // 4, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[0],
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth[0])])

        self.upsample1 = CARAFE4(curr_dim, 64)
        self.norm_up = norm_layer(embed_dim)
        self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
        # Classifier head

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    #Encoder and Bottleneck
    def forward_features(self, x):
        x = self.stage1_conv_embed(x)

        x = self.pos_drop(x)

        for blk in self.stage1:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        self.x1 = x
        x = self.merge1(x)

        for blk in self.stage2:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        self.x2 = x
        x = self.merge2(x)

        for blk in self.stage3:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                    x = blk(x)
        self.x3 = x
        x = self.merge3(x)

        for blk in self.stage4:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.norm(x)

        return x

    #Dencoder and Skip connection
    def forward_up_features(self, x):
        for blk in self.stage_up4:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        x = self.upsample4(x)
        x = torch.cat([self.x3, x],-1)
        x = self.concat_linear4(x)
        for blk in self.stage_up3:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        # print("decoder stage3", x.shape)
        x = self.upsample3(x)
        x = torch.cat([self.x2, x],-1)
        x = self.concat_linear3(x)
        for blk in self.stage_up2:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                    x = blk(x)
        x = self.upsample2(x)
        x = torch.cat([self.x1, x],-1)
        x = self.concat_linear2(x)
        for blk in self.stage_up1:
            if self.use_chk:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        x = self.norm_up(x)  # B L C
        return x

    def up_x4(self, x):
        B, new_HW, C = x.shape
        H = W = int(np.sqrt(new_HW))
        x = self.upsample1(x)
        x = x.view(B, 4 * H, 4 * W, -1)
        x = x.permute(0, 3, 1, 2)  # B,C,H,W
        x = self.output(x)

        return x

    def forward(self, x):
        x = self.forward_features(x)

        x = self.forward_up_features(x)

        x = self.up_x4(x)


        return x

Overwriting cswin_unet.py


In [14]:
%%writefile vision_transformer.py

import torch
import torch.nn as nn
import copy
from cswin_unet import CSWinTransformer

class CSwinUnet(nn.Module):
    def __init__(
        self,
        img_size=224,
        num_classes=1,
        patch_size=4,
        embed_dim=64,
        # embed_dim=96,
        depth=[1, 2, 9, 1],
        split_size=[1, 2, 7, 7],
        num_heads=[2, 4, 8, 16],
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        drop_path_rate=0.1,
        pretrained_ckpt=None,
    ):
        """
        Initialize the CSWinUnet model for deblurring.

        :param img_size: Size of input images (default is 224).
        :param num_classes: Number of output channels (1 for grayscale).
        :param patch_size: Patch size for the transformer (default is 4).
        :param embed_dim: Embedding dimension (default is 96).
        :param depth: Depth of each stage (list of integers).
        :param split_size: Split size for local attention at each stage.
        :param num_heads: Number of attention heads at each stage.
        :param mlp_ratio: MLP ratio in transformer blocks.
        :param qkv_bias: Whether to use bias in QKV projection.
        :param qk_scale: Scaling factor for QK projection (default is None).
        :param drop_rate: Dropout rate.
        :param drop_path_rate: Stochastic depth rate.
        :param pretrained_ckpt: Path to a pretrained checkpoint file (optional).
        """
        super(CSwinUnet, self).__init__()
        self.num_classes = num_classes

        # Initialize the CSWinTransformer model
        self.cswin_unet = CSWinTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,  # Grayscale images
            num_classes=self.num_classes,
            embed_dim=embed_dim,
            depth=depth,
            split_size=split_size,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
        )

        print("CSWinUnet model initialized.")

        # Load pretrained weights if provided
        if pretrained_ckpt:
            self.load_from(pretrained_ckpt)

    def forward(self, x):
        # if x.size()[1] == 1:
        #     x = x.repeat(1,3,1,1)
        logits = self.cswin_unet(x)
        return logits

    def load_from(self, pretrained_ckpt):
        """
        Load a pretrained model from a checkpoint file.

        :param pretrained_ckpt: Path to the pretrained checkpoint file.
        """
        print(f"Loading pretrained model from {pretrained_ckpt}")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        pretrained_dict = torch.load(pretrained_ckpt, map_location=device)

        # Ensure compatibility with current model structure
        pretrained_dict = pretrained_dict.get('state_dict_ema', pretrained_dict)  # Adjust key if nested
        model_dict = self.cswin_unet.state_dict()
        full_dict = copy.deepcopy(pretrained_dict)

        # Adjust stage keys for upsampling compatibility
        for k, v in pretrained_dict.items():
            if "stage" in k:
                current_k = "stage_up" + k[5:]
                full_dict[current_k] = v

        # Remove mismatched keys
        for k in list(full_dict.keys()):
            if k in model_dict and full_dict[k].shape != model_dict[k].shape:
                print(f"Deleting key: {k}, shapes don't match (pretrained: {full_dict[k].shape}, model: {model_dict[k].shape})")
                del full_dict[k]

        # Load state dict into the model
        msg = self.cswin_unet.load_state_dict(full_dict, strict=False)
        print(f"Model loaded successfully: {msg}")


Overwriting vision_transformer.py


In [15]:
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class PairedXrayDataset(Dataset):
    def __init__(self, blurred_dir, clear_dir, transform=None):
        """
        Args:
            blurred_dir (str): Directory with blurred images.
            clear_dir (str): Directory with corresponding clear images.
            transform (callable, optional): Optional transformations to apply to both blurred and clear images.
        """
        # Verify directories exist
        assert os.path.isdir(blurred_dir), f"Directory not found: {blurred_dir}"
        assert os.path.isdir(clear_dir), f"Directory not found: {clear_dir}"
        
        self.blurred_dir = blurred_dir
        self.clear_dir = clear_dir
        self.transform = transform
        self.blurred_images = sorted(os.listdir(blurred_dir))
        self.clear_images = sorted(os.listdir(clear_dir))

        # Ensure both directories have the same number of images
        assert len(self.blurred_images) == len(self.clear_images), \
            "Mismatched number of images in blurred and clear directories."

    def __len__(self):
        return len(self.blurred_images)

    def __getitem__(self, idx):
        # Get file paths for blurred and clear images
        blurred_path = os.path.join(self.blurred_dir, self.blurred_images[idx])
        clear_path = os.path.join(self.clear_dir, self.clear_images[idx])

        # Load images and convert to grayscale if necessary
        blurred_image = Image.open(blurred_path).convert("L")  # "L" mode for grayscale
        clear_image = Image.open(clear_path).convert("L")

        # Apply transformations if provided
        if self.transform:
            blurred_image = self.transform(blurred_image)
            clear_image = self.transform(clear_image)
        else:
            # Default transformations: resize and convert to tensor
            default_transform = transforms.Compose([
                transforms.Resize((224, 224)),  # Resize to match model input size
                transforms.ToTensor(),
            ])
            blurred_image = default_transform(blurred_image)
            clear_image = default_transform(clear_image)

        return blurred_image, clear_image




In [16]:
!pip install pytorch-msssim



In [17]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchmetrics.functional import structural_similarity_index_measure as ssim
import torch
import torch.nn as nn
# from pytorch_msssim import ssim

# class CharbonnierLoss(nn.Module):
#     """Charbonnier Loss (L1)"""
#     def __init__(self, eps=1e-6, out_norm='bci'):
#         super(CharbonnierLoss, self).__init__()
#         self.eps = eps
#         self.out_norm = out_norm

#     def forward(self, x, y):
#         def get_outnorm(tensor: torch.Tensor, out_norm: str):
#             img_shape = tensor.shape
#             norm = 1
#             if 'b' in out_norm:
#                 norm /= img_shape[0]  # normalize by batch size
#             if 'c' in out_norm:
#                 norm /= img_shape[-3]  # normalize by channels
#             if 'i' in out_norm:
#                 norm /= img_shape[-1] * img_shape[-2]  # normalize by image size
#             return norm

#         norm = get_outnorm(x, self.out_norm)
#         loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
#         return loss * norm

# class PerceptualLoss(nn.Module):
#     """Perceptual Loss using VGG features"""
#     def __init__(self):
#         super(PerceptualLoss, self).__init__()
#         vgg = models.vgg19(pretrained=True).features.eval()
#         self.layers = [0, 5, 10, 19, 28]  # Conv layers for perceptual features
#         self.vgg_layers = nn.ModuleList([vgg[i] for i in self.layers])
#         for param in self.vgg_layers.parameters():
#             param.requires_grad = False

#     def forward(self, generated, target):
#         loss = 0.0
#         x_gen, x_target = generated, target
#         for layer in self.vgg_layers:
#             x_gen, x_target = layer(x_gen), layer(x_target)
#             loss += nn.functional.l1_loss(x_gen, x_target)
#         return loss

# # Combined Loss with Charbonnier, SSIM, and Perceptual Loss
# class CombinedLoss(nn.Module):
#     def __init__(self, perceptual_model, charbonnier_eps=1e-6, perceptual_weight=0.1, ssim_weight=0.1):
#         super(CombinedLoss, self).__init__()
#         self.charbonnier_loss = CharbonnierLoss(eps=charbonnier_eps)
#         self.perceptual_model = perceptual_model
#         self.perceptual_model.eval()  # Set VGG to evaluation mode
#         for param in self.perceptual_model.parameters():
#             param.requires_grad = False
#         self.perceptual_weight = perceptual_weight
#         self.ssim_weight = ssim_weight

#     def forward(self, output, target):
#         # Ensure 3 channels for perceptual loss
#         if output.size(1) == 1:  # Check if input has 1 channel
#             output = output.repeat(1, 3, 1, 1)  # Duplicate channels
#         if target.size(1) == 1:
#             target = target.repeat(1, 3, 1, 1)

#         # Charbonnier loss
#         l_charbonnier = self.charbonnier_loss(output, target)

#         # SSIM loss
#         l_ssim = 1 - ssim(output, target)

#         # Perceptual loss
#         output_features = self.perceptual_model(output)
#         target_features = self.perceptual_model(target)
#         l_perceptual = torch.mean((output_features - target_features).pow(2))

#         # Weighted combined loss
#         loss = l_charbonnier + self.ssim_weight * l_ssim + self.perceptual_weight * l_perceptual
#         return loss


# class MAELoss(nn.Module):
#     """Mean Absolute Error Loss"""
#     def __init__(self):
#         super(MAELoss, self).__init__()

#     def forward(self, x, y):
#         return torch.mean(torch.abs(x - y))


# class CombinedLoss(nn.Module):
#     def __init__(self, mae_weight=1.0, ssim_weight=0.1):
#         super(CombinedLoss, self).__init__()
#         self.mae_loss = MAELoss()
#         self.ssim_weight = ssim_weight

#     def forward(self, output, target):
#         # Ensure the input images are grayscale (single channel), no need to repeat channels for SSIM
#         if output.size(1) == 1 and target.size(1) == 1:
#             # MAE loss (pixel-wise difference)
#             l_mae = self.mae_loss(output, target)

#             # SSIM loss (1 - SSIM to make it a loss function)
#             l_ssim = 1 - ssim(output, target)

#         else:
#             raise ValueError("Both output and target should have 1 channel (grayscale images).")

#         # Weighted combined loss
#         loss = l_mae + self.ssim_weight * l_ssim
#         return loss


def trainer_synapse(args, model, snapshot_path, dataloader):
    """
    Training function for deblurring model using Combined loss.
    Args:
        args: Command-line arguments.
        model: The initialized model (e.g., TransUNet).
        snapshot_path: Path to save model checkpoints.
        dataloader: DataLoader providing training batches of blurred and clear images.
    """
    # Create snapshot directory if it doesn't exist
    os.makedirs(snapshot_path, exist_ok=True)

    # Set up Combined loss, optimizer, and learning rate scheduler
    combined_loss = CombinedLoss(lambda_ssim=args.lambda_ssim, lambda_perceptual=args.lambda_perceptual)
    optimizer = optim.AdamW(model.parameters(), lr=args.base_lr, betas=(0.9, 0.999), weight_decay=0.0001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epochs, eta_min=1e-6)

    # Enable GPU support if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Set the model to training mode
    model.train()

    for epoch in range(args.max_epochs):
        running_loss = 0.0
        with tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}/{args.max_epochs}", unit="batch") as pbar:
            for i, (blurred_image, clear_image) in enumerate(dataloader):
                blurred_image = blurred_image.to(device)
                clear_image = clear_image.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                output = model(blurred_image)

                # Calculate losses
                loss = combined_loss(output, clear_image)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                # Update running loss
                running_loss += loss.item()
                pbar.set_postfix({'Loss': running_loss / (i + 1)})
                pbar.update(1)

            # Step the scheduler at the end of each epoch
            scheduler.step()

        # Save the model checkpoint every few epochs
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(snapshot_path, f"epoch_{epoch+1}.pth")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model checkpoint saved at {checkpoint_path}")

    # Final save after training
    final_model_path = os.path.join(snapshot_path, "final_model.pth")
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved at {final_model_path}")


In [21]:
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from vision_transformer import CSwinUnet
# from PairedXrayimages import PairedXrayDataset
from tensorboardX import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from pytorch_msssim import ssim
from tqdm import tqdm
from torchvision.models import vgg19
import matplotlib.pyplot as plt
import pandas as pd

class Args:
    def __init__(self):
        # Data directories
        self.clear_dir = '/kaggle/input/ch-xrays/augmented/augmented'
        self.blurred_dir = '/kaggle/input/ch-xrays/aug_blurred/aug_blurred'
        
        # Model configuration
        self.num_classes = 1
        self.img_size = 224
        
        # Training parameters
        self.max_epochs = 150
        self.batch_size = 8
        self.base_lr = 0.0005
        self.deterministic = 1
        self.seed = 1234
        self.ssim_weight = 0.5
        self.perceptual_weight = 0.1
        
        # Output paths
        self.snapshot_path = '/kaggle/working/model_checkpoints'
        self.plots_dir = '/kaggle/working/training_plots'

# # Define combined loss
# class CombinedLoss:
#     def __init__(self, mae_weight=0.4, ssim_weight=0.6):
#         self.mae_weight = mae_weight
#         self.ssim_weight = ssim_weight
#         self.mae_loss = nn.L1Loss()

#     def __call__(self, output, target):
#         mae = self.mae_loss(output, target)
#         ssim_loss = 1 - ssim(output, target, data_range=1.0, size_average=True)
#         return self.mae_weight * mae + self.ssim_weight * ssim_loss


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""
    def __init__(self, eps=1e-6, out_norm='bci'):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps
        self.out_norm = out_norm

    def forward(self, x, y):
        def get_outnorm(tensor: torch.Tensor, out_norm: str):
            img_shape = tensor.shape
            norm = 1
            if 'b' in out_norm:
                norm /= img_shape[0]  # normalize by batch size
            if 'c' in out_norm:
                norm /= img_shape[-3]  # normalize by channels
            if 'i' in out_norm:
                norm /= img_shape[-1] * img_shape[-2]  # normalize by image size
            return norm

        norm = get_outnorm(x, self.out_norm)
        loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
        return loss * norm

class PerceptualLoss(nn.Module):
    """Perceptual Loss using VGG features"""
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        vgg = models.vgg19(pretrained=True).features.eval()
        self.layers = [0, 5, 10, 19, 28]  # Conv layers for perceptual features
        self.vgg_layers = nn.ModuleList([vgg[i] for i in self.layers])
        for param in self.vgg_layers.parameters():
            param.requires_grad = False

    def forward(self, generated, target):
        loss = 0.0
        x_gen, x_target = generated, target
        for layer in self.vgg_layers:
            x_gen, x_target = layer(x_gen), layer(x_target)
            loss += nn.functional.l1_loss(x_gen, x_target)
        return loss

# Combined Loss with Charbonnier, SSIM, and Perceptual Loss
class CombinedLoss(nn.Module):
    def __init__(self, perceptual_model, charbonnier_eps=1e-6, perceptual_weight=0.1, ssim_weight=0.1):
        super(CombinedLoss, self).__init__()
        self.charbonnier_loss = CharbonnierLoss(eps=charbonnier_eps)
        self.perceptual_model = perceptual_model
        self.perceptual_model.eval()  # Set VGG to evaluation mode
        for param in self.perceptual_model.parameters():
            param.requires_grad = False
        self.perceptual_weight = perceptual_weight
        self.ssim_weight = ssim_weight

    def forward(self, output, target):
        # Ensure 3 channels for perceptual loss
        if output.size(1) == 1:  # Check if input has 1 channel
            output = output.repeat(1, 3, 1, 1)  # Duplicate channels
        if target.size(1) == 1:
            target = target.repeat(1, 3, 1, 1)

        # Charbonnier loss
        l_charbonnier = self.charbonnier_loss(output, target)

        # SSIM loss
        l_ssim = 1 - ssim(output, target)

        # Perceptual loss
        output_features = self.perceptual_model(output)
        target_features = self.perceptual_model(target)
        l_perceptual = torch.mean((output_features - target_features).pow(2))

        # Weighted combined loss
        loss = l_charbonnier + self.ssim_weight * l_ssim + self.perceptual_weight * l_perceptual
        return loss

def save_loss_plot(losses, plots_dir):
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(losses) + 1), losses, 'b-')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.savefig(os.path.join(plots_dir, 'loss_plot.png'))
    plt.close()

def train_model(args, model, dataloader, optimizer, scheduler, combined_loss, device):
    model.train()
    scaler = GradScaler()
    writer = SummaryWriter(logdir=args.snapshot_path)
    
    # Lists to store losses
    epoch_losses = []
    # mae_weight = 0.4
    # ssim_weight = 0.6
    # combined_loss = CombinedLoss(mae_weight, ssim_weight)

    # Early stopping parameters
    patience = 20
    best_loss = float('inf')
    patience_counter = 0

    # Create directory for plots
    os.makedirs(args.plots_dir, exist_ok=True)

    for epoch in range(args.max_epochs):
        running_loss = 0.0
        with tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}/{args.max_epochs}", unit="batch") as pbar:
            for i, (blurred, clear) in enumerate(dataloader):
                blurred, clear = blurred.to(device), clear.to(device)
                optimizer.zero_grad()

                with autocast():
                    output = model(blurred)
                    loss = combined_loss(output, clear)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item()
                pbar.set_postfix({'Loss': running_loss / (i + 1)})
                pbar.update(1)

            scheduler.step()
            
            # Calculate and store average epoch loss
            avg_epoch_loss = running_loss / len(dataloader)
            epoch_losses.append(avg_epoch_loss)

            # # Dynamic weight adjustment
            # if epoch % 10 == 0 and epoch > 0:
            #     if epoch_losses[-1] < best_loss:
            #         if avg_epoch_loss - best_loss > 0.01:  # SSIM improving but PSNR poor
            #             mae_weight = min(1.0, mae_weight + 0.05)
            #         elif best_loss - avg_epoch_loss > 0.01:  # PSNR improving but SSIM stagnant
            #             ssim_weight = min(1.0, ssim_weight + 0.05)
            #         combined_loss = CombinedLoss(mae_weight, ssim_weight)

            # Save epoch losses and plot
            df_epochs = pd.DataFrame({
                'Epoch': range(1, len(epoch_losses) + 1),
                'Loss': epoch_losses
            })
            df_epochs.to_csv(os.path.join(args.plots_dir, 'epoch_losses.csv'), index=False)
            save_loss_plot(epoch_losses, args.plots_dir)

            writer.add_scalar('Training Loss', avg_epoch_loss, epoch)

            # Early stopping check
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch + 1}.")
                break

            # Save checkpoint with only model_state_dict
            if (epoch + 1) % 10 == 0:
                checkpoint_path = os.path.join(args.snapshot_path, f"epoch_{epoch+1}.pth")
                torch.save(model.state_dict(), checkpoint_path)
                print(f"\nCheckpoint saved: {checkpoint_path}")

    # Save final model with only model_state_dict
    final_model_path = os.path.join(args.snapshot_path, "final_model.pth")
    torch.save(model.state_dict(), final_model_path)
    
    print(f"\nTraining completed. Model saved to {final_model_path}")
    print(f"Best loss: {min(epoch_losses):.6f} at epoch {epoch_losses.index(min(epoch_losses)) + 1}")

def main():
    args = Args()

    # Set deterministic behavior
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    # Create snapshot directory
    os.makedirs(args.snapshot_path, exist_ok=True)

    # Set up logging
    logging.basicConfig(
        filename=os.path.join(args.snapshot_path, "train.log"),
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )
    logging.info("Training started with configuration: %s", vars(args))

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize model
    model = CSwinUnet(img_size=args.img_size, num_classes=args.num_classes).to(device)
    logging.info("Model initialized")

    # Initialize VGG for perceptual loss
    vgg = vgg19(pretrained=True).features[:9].to(device).eval()
    for param in vgg.parameters():
        param.requires_grad = False

    # Initialize combined loss
    combined_loss = CombinedLoss(
        perceptual_model=vgg,
        ssim_weight=args.ssim_weight,
        perceptual_weight=args.perceptual_weight
    )
    logging.info("Loss functions initialized")

    # Initialize optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=args.base_lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.max_epochs, eta_min=1e-6)

    # Define transformations
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor()
    ])

    # Initialize dataset and dataloader
    dataset = PairedXrayDataset(
        blurred_dir=args.blurred_dir,
        clear_dir=args.clear_dir,
        transform=transform
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    logging.info(f"Dataloader initialized with {len(dataset)} samples")

    # Start training
    train_model(args, model, dataloader, optimizer, scheduler, combined_loss, device)
    logging.info("Training completed")

if __name__ == "__main__":
    main()

Using device: cuda
depth [1, 2, 9, 1]
CSWinUnet model initialized.


  scaler = GradScaler()
  with autocast():
Epoch 1/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0905]
Epoch 2/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0404]
Epoch 3/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0374]
Epoch 4/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0326]
Epoch 5/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0343]
Epoch 6/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0332]
Epoch 7/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.0307]
Epoch 8/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.0295]
Epoch 9/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0298]
Epoch 10/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0294]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_10.pth


Epoch 11/150: 100%|██████████| 160/160 [00:38<00:00,  4.17batch/s, Loss=0.0287]
Epoch 12/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.028] 
Epoch 13/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0275]
Epoch 14/150: 100%|██████████| 160/160 [00:38<00:00,  4.19batch/s, Loss=0.0271]
Epoch 15/150: 100%|██████████| 160/160 [00:38<00:00,  4.19batch/s, Loss=0.0263]
Epoch 16/150: 100%|██████████| 160/160 [00:38<00:00,  4.19batch/s, Loss=0.0272]
Epoch 17/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0247]
Epoch 18/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0266]
Epoch 19/150: 100%|██████████| 160/160 [00:38<00:00,  4.20batch/s, Loss=0.0242]
Epoch 20/150: 100%|██████████| 160/160 [00:38<00:00,  4.18batch/s, Loss=0.0244]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_20.pth


Epoch 21/150: 100%|██████████| 160/160 [00:38<00:00,  4.18batch/s, Loss=0.0261]
Epoch 22/150: 100%|██████████| 160/160 [00:38<00:00,  4.20batch/s, Loss=0.0231]
Epoch 23/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0236]
Epoch 24/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0249]
Epoch 25/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0249]
Epoch 26/150: 100%|██████████| 160/160 [00:38<00:00,  4.19batch/s, Loss=0.0233]
Epoch 27/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0233]
Epoch 28/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.022] 
Epoch 29/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.021] 
Epoch 30/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0215]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_30.pth


Epoch 31/150: 100%|██████████| 160/160 [00:38<00:00,  4.20batch/s, Loss=0.0216]
Epoch 32/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0218]
Epoch 33/150: 100%|██████████| 160/160 [00:37<00:00,  4.21batch/s, Loss=0.0205]
Epoch 34/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0219]
Epoch 35/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0212]
Epoch 36/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0203]
Epoch 37/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0202]
Epoch 38/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0198]
Epoch 39/150: 100%|██████████| 160/160 [00:38<00:00,  4.21batch/s, Loss=0.02]  
Epoch 40/150: 100%|██████████| 160/160 [00:38<00:00,  4.21batch/s, Loss=0.0186]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_40.pth


Epoch 41/150: 100%|██████████| 160/160 [00:38<00:00,  4.13batch/s, Loss=0.0182]
Epoch 42/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0192]
Epoch 43/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0183]
Epoch 44/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0187]
Epoch 45/150: 100%|██████████| 160/160 [00:38<00:00,  4.20batch/s, Loss=0.0183]
Epoch 46/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0176]
Epoch 47/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.017] 
Epoch 48/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.017] 
Epoch 49/150: 100%|██████████| 160/160 [00:38<00:00,  4.19batch/s, Loss=0.0165]
Epoch 50/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0165]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_50.pth


Epoch 51/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0165]
Epoch 52/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.0172]
Epoch 53/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0174]
Epoch 54/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0159]
Epoch 55/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0169]
Epoch 56/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0161]
Epoch 57/150: 100%|██████████| 160/160 [00:38<00:00,  4.15batch/s, Loss=0.0156]
Epoch 58/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0156]
Epoch 59/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0156]
Epoch 60/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0155]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_60.pth


Epoch 61/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0157]
Epoch 62/150: 100%|██████████| 160/160 [00:38<00:00,  4.20batch/s, Loss=0.0145]
Epoch 63/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0159]
Epoch 64/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0147]
Epoch 65/150: 100%|██████████| 160/160 [00:38<00:00,  4.21batch/s, Loss=0.0145]
Epoch 66/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.014] 
Epoch 67/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0145]
Epoch 68/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0146]
Epoch 69/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0136]
Epoch 70/150: 100%|██████████| 160/160 [00:38<00:00,  4.17batch/s, Loss=0.0136]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_70.pth


Epoch 71/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0138]
Epoch 72/150: 100%|██████████| 160/160 [00:37<00:00,  4.21batch/s, Loss=0.0134]
Epoch 73/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0134]
Epoch 74/150: 100%|██████████| 160/160 [00:38<00:00,  4.18batch/s, Loss=0.013] 
Epoch 75/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0131]
Epoch 76/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.0135]
Epoch 77/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0128]
Epoch 78/150: 100%|██████████| 160/160 [00:37<00:00,  4.21batch/s, Loss=0.0126]
Epoch 79/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.0129]
Epoch 80/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.0129]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_80.pth


Epoch 81/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.0122]
Epoch 82/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.0121]
Epoch 83/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0121]
Epoch 84/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.0119]
Epoch 85/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.0119]
Epoch 86/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.0118]
Epoch 87/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0117]
Epoch 88/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.0118]
Epoch 89/150: 100%|██████████| 160/160 [00:37<00:00,  4.24batch/s, Loss=0.0116]
Epoch 90/150: 100%|██████████| 160/160 [00:37<00:00,  4.22batch/s, Loss=0.0115]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_90.pth


Epoch 91/150: 100%|██████████| 160/160 [00:38<00:00,  4.16batch/s, Loss=0.0115]
Epoch 92/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0113]
Epoch 93/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.0112]
Epoch 94/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0114]
Epoch 95/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.0111]
Epoch 96/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0109]
Epoch 97/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.0109]
Epoch 98/150: 100%|██████████| 160/160 [00:37<00:00,  4.31batch/s, Loss=0.0108]
Epoch 99/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.0107]
Epoch 100/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0111]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_100.pth


Epoch 101/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0107]
Epoch 102/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.0105]
Epoch 103/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.0104]
Epoch 104/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.0105]
Epoch 105/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.0104]
Epoch 106/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.0103]
Epoch 107/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0103]
Epoch 108/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.0103]
Epoch 109/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0103]
Epoch 110/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.0101]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_110.pth


Epoch 111/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0102]
Epoch 112/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0101] 
Epoch 113/150: 100%|██████████| 160/160 [00:37<00:00,  4.31batch/s, Loss=0.0101] 
Epoch 114/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.01]  
Epoch 115/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00994]
Epoch 116/150: 100%|██████████| 160/160 [00:37<00:00,  4.31batch/s, Loss=0.00992]
Epoch 117/150: 100%|██████████| 160/160 [00:37<00:00,  4.31batch/s, Loss=0.00987]
Epoch 118/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.00985]
Epoch 119/150: 100%|██████████| 160/160 [00:37<00:00,  4.31batch/s, Loss=0.00983]
Epoch 120/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.00981]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_120.pth


Epoch 121/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.00976]
Epoch 122/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.00978]
Epoch 123/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00975]
Epoch 124/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00971]
Epoch 125/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00967]
Epoch 126/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.00967]
Epoch 127/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.00967]
Epoch 128/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00963]
Epoch 129/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.00959]
Epoch 130/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.00957]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_130.pth


Epoch 131/150: 100%|██████████| 160/160 [00:38<00:00,  4.19batch/s, Loss=0.00956]
Epoch 132/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.00955]
Epoch 133/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.00957]
Epoch 134/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00953]
Epoch 135/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.00954]
Epoch 136/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.0095] 
Epoch 137/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.00948]
Epoch 138/150: 100%|██████████| 160/160 [00:37<00:00,  4.25batch/s, Loss=0.0095] 
Epoch 139/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.0095] 
Epoch 140/150: 100%|██████████| 160/160 [00:37<00:00,  4.23batch/s, Loss=0.00948]



Checkpoint saved: /kaggle/working/model_checkpoints/epoch_140.pth


Epoch 141/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.00949]
Epoch 142/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.00945]
Epoch 143/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.00945]
Epoch 144/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00947]
Epoch 145/150: 100%|██████████| 160/160 [00:37<00:00,  4.30batch/s, Loss=0.00946]
Epoch 146/150: 100%|██████████| 160/160 [00:37<00:00,  4.28batch/s, Loss=0.00945]
Epoch 147/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00947]
Epoch 148/150: 100%|██████████| 160/160 [00:37<00:00,  4.29batch/s, Loss=0.00944]
Epoch 149/150: 100%|██████████| 160/160 [00:37<00:00,  4.27batch/s, Loss=0.00943]
Epoch 150/150: 100%|██████████| 160/160 [00:37<00:00,  4.26batch/s, Loss=0.00943]


Checkpoint saved: /kaggle/working/model_checkpoints/epoch_150.pth

Training completed. Model saved to /kaggle/working/model_checkpoints/final_model.pth
Best loss: 0.009431 at epoch 150





In [30]:
import os
import torch
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from pytorch_msssim import ssim
from skimage.metrics import peak_signal_noise_ratio as calculate_psnr
from PIL import Image
from vision_transformer import CSwinUnet

class TestArgs:
    def __init__(self):
        self.test_blurred_dir = '/kaggle/input/ch-xrays-test/blurredd/blurredd'
        self.test_sharp_dir = '/kaggle/input/ch-xrays-test/sharpp/sharpp'
        self.generated_dir = '/kaggle/working/generated'
        self.comparison_dir = '/kaggle/working/comparison'  # Add this line for comparison folder
        self.model_path = '/kaggle/working/model_checkpoints/final_model.pth'
        self.img_size = 224
        self.batch_size = 1

# Test dataset
class TestXrayDataset(torch.utils.data.Dataset):
    def __init__(self, blurred_dir, sharp_dir, transform=None):
        self.blurred_dir = blurred_dir
        self.sharp_dir = sharp_dir
        self.image_paths = [os.path.join(blurred_dir, f) for f in os.listdir(blurred_dir) if f.endswith(('.PNG', '.jpg', '.jpeg'))]  # .PNG extension handled
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        blurred_img_path = self.image_paths[idx]
        sharp_img_path = os.path.join(self.sharp_dir, os.path.basename(blurred_img_path))
        
        blurred_img = Image.open(blurred_img_path).convert('L')  # Convert to grayscale
        sharp_img = Image.open(sharp_img_path).convert('L')  # Convert to grayscale

        if self.transform:
            blurred_img = self.transform(blurred_img)
            sharp_img = self.transform(sharp_img)

        return blurred_img, sharp_img, blurred_img_path

import matplotlib.pyplot as plt

@torch.no_grad()
def test_model(args, model, dataloader, device):
    model.eval()
    os.makedirs(args.generated_dir, exist_ok=True)
    os.makedirs(args.comparison_dir, exist_ok=True)  # Create comparison folder

    psnr_values = []
    ssim_values = []

    transform_to_pil = transforms.ToPILImage()

    for blurred, sharp, img_path in dataloader:
        blurred, sharp = blurred.to(device), sharp.to(device)
        
        # Get the deblurred output
        deblurred = model(blurred)
        deblurred = torch.clamp(deblurred, 0, 1)  # Ensure values are in [0, 1]

        # Save the deblurred image
        output_image = transform_to_pil(deblurred.squeeze(0).cpu())
        output_path = os.path.join(args.generated_dir, os.path.basename(img_path[0]))
        output_image.save(output_path)

        # Load the blurred and ground truth images to prepare the comparison image
        blurred_image = transform_to_pil(blurred.squeeze(0).cpu())
        sharp_image = transform_to_pil(sharp.squeeze(0).cpu())

        # Create a figure for comparison
        fig, ax = plt.subplots(1, 3, figsize=(12, 4))

        # Plot the images side by side
        ax[0].imshow(blurred_image, cmap='gray')
        ax[0].set_title('Blurred')
        ax[0].axis('off')

        ax[1].imshow(output_image, cmap='gray')
        ax[1].set_title('Generated')
        ax[1].axis('off')

        ax[2].imshow(sharp_image, cmap='gray')
        ax[2].set_title('Ground Truth')
        ax[2].axis('off')

        # Save the comparison image
        comparison_image_path = os.path.join(args.comparison_dir, os.path.basename(img_path[0]).replace(".PNG", "_comparison.png"))
        plt.tight_layout()
        plt.savefig(comparison_image_path)
        plt.close(fig)

        # Calculate PSNR and SSIM
        sharp_np = sharp.squeeze(0).cpu().numpy().transpose(1, 2, 0)  # Convert to HxWxC
        deblurred_np = deblurred.squeeze(0).cpu().numpy().transpose(1, 2, 0)  # Convert to HxWxC

        psnr_value = calculate_psnr(sharp_np, deblurred_np, data_range=1.0)
        ssim_value = ssim(torch.tensor(sharp_np).unsqueeze(0).unsqueeze(0), torch.tensor(deblurred_np).unsqueeze(0).unsqueeze(0), data_range=1.0, size_average=True).item()

        psnr_values.append(psnr_value)
        ssim_values.append(ssim_value)

        print(f"Processed: {os.path.basename(img_path[0])} | PSNR: {psnr_value:.4f} | SSIM: {ssim_value:.4f}")

    # Calculate and print average PSNR and SSIM
    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)

    print("\nTest Completed")
    print(f"Average PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")

    return avg_psnr, avg_ssim


def main():
    args = TestArgs()

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the model
    model = CSwinUnet(img_size=args.img_size, num_classes=1).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    print("Model loaded successfully")

    # Define transformations
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor()
    ])

    # Initialize dataset and dataloader
    test_dataset = TestXrayDataset(
        blurred_dir=args.test_blurred_dir,
        sharp_dir=args.test_sharp_dir,
        transform=transform
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    print(f"Testing on {len(test_dataset)} images")

    # Test the model
    avg_psnr, avg_ssim = test_model(args, model, test_dataloader, device)

    print("\nResults:")
    print(f"Average PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")

if __name__ == "__main__":
    main()


Using device: cuda
depth [1, 2, 9, 1]
CSWinUnet model initialized.
Model loaded successfully
Testing on 100 images


  model.load_state_dict(torch.load(args.model_path, map_location=device))


Processed: 258_img_258_original_person496_bacteria_2095.jpeg_2e6bf10d-9f00-44c3-a8c4-3706a9b7113b.PNG | PSNR: 33.2562 | SSIM: 0.9246
Processed: 259_img_259_original_person498_bacteria_2100.jpeg_30d2d192-3672-4e3f-8de1-fd8b4f03b152.PNG | PSNR: 33.9725 | SSIM: 0.9286
Processed: 271_img_271_original_person707_virus_1305.jpeg_4e02b1ce-0e13-4f46-82bc-b551226f6ea9.PNG | PSNR: 34.6335 | SSIM: 0.9437
Processed: 264_img_264_original_person593_bacteria_2435.jpeg_14025f99-7553-4fa8-818d-c4af7614859a.PNG | PSNR: 37.7017 | SSIM: 0.9526
Processed: 275_img_275_original_person802_bacteria_2708.jpeg_ed48aec4-f0b2-4d13-a64b-bf99aebe2e9b.PNG | PSNR: 33.3045 | SSIM: 0.9248
Processed: 257_img_257_original_person469_virus_965.jpeg_5d7a5d2c-87a1-4765-acc8-d365b68ffaee.PNG | PSNR: 35.8902 | SSIM: 0.9033
Processed: 264_img_264_original_person593_bacteria_2435.jpeg_92ac4fa2-2769-4226-b9ea-0a73b24b4c40.PNG | PSNR: 37.4689 | SSIM: 0.9508
Processed: 265_img_265_original_person609_virus_1176.jpeg_c3eb4017-782e-4662

In [26]:
import shutil

# Path to the generated folder
generated_dir = '/kaggle/working/generated'

# Create a zip file
shutil.make_archive(generated_dir, 'zip', generated_dir)

'/kaggle/working/generated.zip'

In [31]:
import shutil

# Path to the generated folder
generated_dir = '/kaggle/working/comparison'

# Create a zip file
shutil.make_archive(generated_dir, 'zip', generated_dir)

'/kaggle/working/comparison.zip'