# Download Dataset

In [None]:
!pip install kaggle timm --quiet

!mkdir -p ~/.kaggle
!echo '{"username":"hafijulhoque987 ","key":"ba267fc402273b17f82059844f85fe32"}' > ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

# Download the dataset
!kaggle datasets download -d andrewmvd/pediatric-pneumonia-chest-xray

# Extract the dataset
import zipfile

with zipfile.ZipFile('pediatric-pneumonia-chest-xray.zip', 'r') as zip_ref:
    zip_ref.extractall('./pediatric_pneumonia')

print("Dataset downloaded and extracted successfully!")


In [None]:
import torch
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', device)

# Transformer Classes

In [None]:
import copy
import math
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn

class PositionEmbeddingSine(nn.Module):
    """Fixed 2-D sine/cosine positional encoding for image feature maps.

    This generalizes the 1-D encoding of "Attention Is All You Need" to images.
    Row (y) and column (x) positions are encoded independently, each into
    ``num_pos_feats`` channels, and concatenated as ``[y-encoding, x-encoding]``.

    Args:
        num_pos_feats: Channels per axis. The output has ``2 * num_pos_feats`` channels.
        temperature: Base of the geometric frequency progression.
        normalize: If True, rescale positions to the range ``(0, scale]`` before encoding.
        scale: Normalization range. Defaults to ``2 * pi`` and is only allowed
            together with ``normalize=True``.

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """Encode the positions of ``x``.

        Only ``x``'s batch size, dims 2 and 3 and device are read; its values are ignored.
        ``mask`` (shape ``(x.size(0), x.size(2), x.size(3))``, True = padded) excludes
        positions from the running position count. When omitted, no position is masked.

        Returns:
            Tensor of shape ``(x.size(0), 2 * num_pos_feats, x.size(2), x.size(3))``.
        """
        if mask is None:
            mask = torch.zeros(
                (x.size(0), x.size(2), x.size(3)), dtype=torch.bool, device=x.device
            )
        valid = ~mask
        # 1-based running index of each unmasked position along rows / columns.
        row_pos = valid.cumsum(1, dtype=torch.float32)
        col_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            row_pos = row_pos / (row_pos[:, -1:, :] + eps) * self.scale
            col_pos = col_pos / (col_pos[:, :, -1:] + eps) * self.scale

        # Each (sin, cos) channel pair shares one frequency: temperature ** (2i / num_pos_feats).
        idx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        freqs = self.temperature ** (2 * (idx // 2) / self.num_pos_feats)

        def _encode(positions):
            # Interleave sin (even channels) and cos (odd channels).
            scaled = positions[:, :, :, None] / freqs
            return torch.stack(
                (scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        return torch.cat((_encode(row_pos), _encode(col_pos)), dim=3).permute(0, 3, 1, 2)

class Transformer(nn.Module):
    """Encoder/decoder transformer that attends from a few queries over a flattened feature map.

    Every weight matrix (parameter with more than one dimension) is re-initialized
    with Xavier-uniform after construction.

    Args:
        d_model: Embedding width shared by queries, keys and values.
        nhead: Number of attention heads.
        num_encoder_layers: Encoder depth. With 0, the encoder adds only the optional
            final norm, which exists only when ``normalize_before`` is True.
        num_decoder_layers: Decoder depth.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        dropout: Dropout probability used throughout.
        activation: Feed-forward activation name (see ``_get_activation_fn``).
        normalize_before: Use pre-norm layers instead of post-norm.
        return_intermediate_dec: Make the decoder return every layer's output stacked.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()
        layer_args = (d_model, nhead, dim_feedforward, dropout, activation, normalize_before)

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(*layer_args),
            num_encoder_layers,
            nn.LayerNorm(d_model) if normalize_before else None,
        )
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(*layer_args),
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform re-init of every parameter with more than one dimension."""
        for weight in filter(lambda t: t.dim() > 1, self.parameters()):
            nn.init.xavier_uniform_(weight)

    def forward(self, src, mask, query_embed, query_pos, value_pos):
        """Attend from ``query_embed`` over the spatial positions of ``src``.

        Args:
            src: Key/value feature map, flattened from dim 2 onward. In this
                notebook's callers it is ``(B, C, H, W)``.
            mask: Optional key-padding mask, flattened from dim 1 onward (True = ignore).
            query_embed: ``(B, C, L)`` decoder queries.
            query_pos: ``(C, L)`` query positional embedding, shared across the batch.
            value_pos: Positional encoding laid out like ``src``.

        Returns:
            ``(B, C, L)`` decoder output (with ``return_intermediate_dec=False``).
        """
        batch, _, _ = query_embed.shape
        # The attention modules are sequence-first, so every input goes to (seq, B, C).
        tgt = query_embed.permute(2, 0, 1)
        tgt_pos = query_pos.unsqueeze(0).expand(batch, -1, -1).permute(2, 0, 1)
        memory_in = src.flatten(2).permute(2, 0, 1)
        memory_pos = value_pos.flatten(2).permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)

        memory = self.encoder(memory_in, src_key_padding_mask=mask, pos=memory_pos)
        hs = self.decoder(
            tgt=tgt,
            memory=memory,
            memory_key_padding_mask=mask,
            pos=memory_pos,
            query_pos=tgt_pos,
        )
        # (L, B, C) -> (B, C, L)
        return hs.transpose(1, 2).transpose(0, 2)


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` independent copies of ``encoder_layer`` with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Run ``src`` through every layer in order, then apply ``norm`` if one was given."""
        out = src
        for enc_layer in self.layers:
            out = enc_layer(
                out, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
            )
        return out if self.norm is None else self.norm(out)


class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` independent copies of ``decoder_layer``.

    Args:
        decoder_layer: Layer module to deep-copy ``num_layers`` times.
        num_layers: Number of copies.
        norm: Optional module applied to the final output and, when
            ``return_intermediate`` is set, to every intermediate output.
        return_intermediate: If True, return the outputs of all layers stacked
            along a new leading dimension instead of only the last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Decode ``tgt`` against ``memory``.

        Returns:
            The (normed) output of the last layer, shaped like ``tgt``. With
            ``return_intermediate``, a ``torch.stack`` of every layer's (normed) output.
        """
        output = tgt
        intermediate = []

        for layer in self.layers:
            output = layer(
                tgt=output,
                memory=memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # Previously self.norm(output) was called unconditionally here, which
                # raised TypeError when the decoder was built without a norm.
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # The final entry is the same normed tensor that is returned in the
                # non-intermediate case.
                intermediate[-1] = output

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output


class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a position-wise feed-forward block.

    The attention module uses the default ``batch_first=False``, so inputs are
    sequence-first. ``normalize_before`` selects pre-norm (LayerNorm before each
    sub-block); the default is post-norm (LayerNorm after each residual add).
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` when a positional embedding is given."""
        return tensor if pos is None else tensor + pos

    def _ffn(self, x):
        """Feed-forward sub-block without its residual connection."""
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Post-norm path: residual add, then LayerNorm, after each sub-block."""
        attn_in = self.with_pos_embed(src, pos)
        attn_out, _ = self.self_attn(
            attn_in, attn_in, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = self.norm1(src + self.dropout1(attn_out))
        return self.norm2(src + self.dropout2(self._ffn(src)))

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Pre-norm path: LayerNorm before each sub-block, un-normed residual stream."""
        normed = self.norm1(src)
        attn_in = self.with_pos_embed(normed, pos)
        attn_out, _ = self.self_attn(
            attn_in, attn_in, value=normed, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(attn_out)
        return src + self.dropout2(self._ffn(self.norm2(src)))

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Dispatch to the pre-norm or post-norm path chosen at construction."""
        step = self.forward_pre if self.normalize_before else self.forward_post
        return step(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    """One decoder layer with three sub-blocks: self-attention, cross-attention, feed-forward.

    Both attention modules use the default ``batch_first=False``, so ``tgt`` and
    ``memory`` are sequence-first. ``query_pos`` is added to the queries and keys of
    self-attention and to the queries of cross-attention. ``pos`` is added to the
    cross-attention keys only; values never receive positional embeddings.
    ``normalize_before`` selects pre-norm instead of the default post-norm.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add ``pos`` to ``tensor`` when a positional embedding is given."""
        return tensor if pos is None else tensor + pos

    def _ffn(self, x):
        """Feed-forward sub-block without its residual connection."""
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Post-norm path: residual add, then LayerNorm, after each sub-block."""
        qk = self.with_pos_embed(tgt, query_pos)
        self_out, _ = self.self_attn(
            qk, qk, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = self.norm1(tgt + self.dropout1(self_out))

        cross_out, _ = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = self.norm2(tgt + self.dropout2(cross_out))

        return self.norm3(tgt + self.dropout3(self._ffn(tgt)))

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Pre-norm path: LayerNorm before each sub-block, un-normed residual stream."""
        normed = self.norm1(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        self_out, _ = self.self_attn(
            qk, qk, value=normed, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(self_out)

        normed = self.norm2(tgt)
        cross_out, _ = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = tgt + self.dropout2(cross_out)

        return tgt + self.dropout3(self._ffn(self.norm3(tgt)))

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Dispatch to the pre-norm or post-norm path chosen at construction."""
        step = self.forward_pre if self.normalize_before else self.forward_post
        return step(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")

# UNet Classes

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two ``Conv3x3 -> BatchNorm -> ReLU`` stages; ``padding=1`` keeps H and W unchanged."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order define the state_dict keys of saved checkpoints.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.double_conv(x)

class UNet(nn.Module):
    """Four-level U-Net whose bottleneck is re-weighted per channel by a transformer.

    The encoder halves the resolution four times (plus once more before the
    bottleneck), so H and W presumably need to be divisible by 16 for the skip
    concatenations to line up. TODO confirm for the inputs used.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the sigmoid output map.
        feature_dims: Width of the first encoder stage. Each level doubles it, and
            the bottleneck has ``16 * feature_dims`` channels.
    """

    def __init__(self, in_channels, out_channels, feature_dims):
        super().__init__()
        f = feature_dims
        # Submodules are created in the original order so that attribute names (state_dict
        # keys) and random initialization stay identical.
        self.encoder1 = DoubleConv(in_channels, f)
        self.encoder2 = DoubleConv(f, f * 2)
        self.encoder3 = DoubleConv(f * 2, f * 4)
        self.encoder4 = DoubleConv(f * 4, f * 8)

        # Learned (16f, 1) positional embedding for the single bottleneck query.
        self.embed = nn.Embedding(f * 16, 1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 2 * (16f // 2) channels, matching the transformer width d_model = 16f.
        self.pe_layer = PositionEmbeddingSine(f * 16 // 2, normalize=True)
        self.transformer = Transformer(
            d_model=f * 16,
            dropout=0.1,
            nhead=4,
            dim_feedforward=f * 4,
            num_encoder_layers=0,
            num_decoder_layers=1,
            normalize_before=False,
            return_intermediate_dec=False,
        )

        self.bottleneck_conv = DoubleConv(f * 8, f * 16)

        self.upconv4 = nn.ConvTranspose2d(f * 16, f * 8, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(f * 16, f * 8)
        self.upconv3 = nn.ConvTranspose2d(f * 8, f * 4, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(f * 8, f * 4)
        self.upconv2 = nn.ConvTranspose2d(f * 4, f * 2, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(f * 4, f * 2)
        self.upconv1 = nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(f * 2, f)

        self.final_conv = nn.Conv2d(f, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``sigmoid(final_conv(...))``, i.e. per-pixel values in (0, 1)."""
        # Contracting path: every stage after the first is preceded by a 2x max-pool.
        skips = []
        feat = x
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4)
        for depth, encoder in enumerate(encoders):
            feat = encoder(self.pool(feat) if depth > 0 else feat)
            skips.append(feat)

        bottleneck = self.bottleneck_conv(self.pool(feat))
        # Unpacking doubles as a rank check: the transformer path needs (B, C, H, W).
        _batch, _channels, _height, _width = bottleneck.shape

        # A single query (the global average of the bottleneck, (B, C, 1)) attends over all
        # bottleneck positions. The (B, C, 1) result gates the bottleneck channel-wise.
        query = bottleneck.mean(dim=(2, 3), keepdim=True).squeeze(-1)
        key_pos = self.pe_layer(bottleneck)
        gate = self.transformer(bottleneck, None, query, self.embed.weight, key_pos)
        out = bottleneck * gate.unsqueeze(-1)

        # Expanding path: upsample, concatenate the matching skip, fuse.
        stages = (
            (self.upconv4, self.decoder4),
            (self.upconv3, self.decoder3),
            (self.upconv2, self.decoder2),
            (self.upconv1, self.decoder1),
        )
        for (upconv, decoder), skip in zip(stages, reversed(skips)):
            out = decoder(torch.cat((upconv(out), skip), dim=1))

        return torch.sigmoid(self.final_conv(out))

# Backbone

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet50, resnet101
from torchvision.models._utils import IntermediateLayerGetter

class FrozenBatchNorm2d(torch.nn.Module):
    """BatchNorm2d whose affine parameters and running statistics are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered as buffers,
    not parameters, so ``parameters()`` does not return them and forward() never updates them.
    """

    def __init__(self, n):
        super().__init__()
        for buffer_name, factory in (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        ):
            self.register_buffer(buffer_name, factory(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints from a regular BatchNorm2d carry a num_batches_tracked counter that
        # this module has no buffer for; drop it so strict loading does not fail.
        state_dict.pop(prefix + 'num_batches_tracked', None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        """Compute ``x * scale + shift``, the affine form of inference-mode BatchNorm."""
        def as_4d(t):
            return t.reshape(1, -1, 1, 1)

        eps = 1e-5
        scale = as_4d(self.weight) * (as_4d(self.running_var) + eps).rsqrt()
        shift = as_4d(self.bias) - as_4d(self.running_mean) * scale
        return x * scale + shift


class BackboneBase(nn.Module):
    """Wrap a ResNet-style module so forward() returns a dict of its stage outputs.

    Args:
        backbone: Module with ``layer1``..``layer4`` children.
        train_backbone: If False, every backbone parameter is frozen. If True, only
            parameters whose name contains ``layer2``, ``layer3`` or ``layer4`` stay trainable.
        num_channels: Stored on the instance for callers; not used here.
        return_interm_layers: Return ``layer1``..``layer4`` under keys ``"0"``..``"3"``
            instead of only ``layer4`` under key ``"0"``.
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for param_name, param in backbone.named_parameters():
            keep_trainable = train_backbone and any(stage in param_name for stage in trainable_stages)
            if not keep_trainable:
                param.requires_grad_(False)

        if return_interm_layers:
            return_layers = {f"layer{i + 1}": str(i) for i in range(4)}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, x):
        """Return the dict of selected stage outputs produced by ``self.body``."""
        return self.body(x)

# Supported ResNet constructors, keyed by the name passed to Backbone.
resnets_dict = dict(
    resnet50=resnet50,
    resnet101=resnet101,
)


class Backbone(BackboneBase):
    """ImageNet-pretrained torchvision ResNet with frozen BatchNorm layers.

    Args:
        name: Key into ``resnets_dict`` (``'resnet50'`` or ``'resnet101'``).
        train_backbone: Forwarded to BackboneBase (controls which parameters stay trainable).
        return_interm_layers: Forwarded to BackboneBase (all stages vs. ``layer4`` only).
        dilation: ``replace_stride_with_dilation`` flags for layer2..layer4.
    """

    def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: list):
        build_resnet = resnets_dict[name]
        backbone = build_resnet(
            pretrained=True,
            replace_stride_with_dilation=dilation,
            norm_layer=FrozenBatchNorm2d,
        )
        # Basic-block ResNets (18/34) end with 512 channels, bottleneck ResNets with 2048.
        num_channels = 2048 if name not in ('resnet18', 'resnet34') else 512
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

# MSR Classes

## New base++ (`MSR`: three ResNet-stage transformer heads, fused)

In [None]:
import torch
import torch.nn as nn
from timm import create_model
from torchvision.models import vit_b_16


class MSR(nn.Module):
    """Multi-scale ResNet classifier with one single-query transformer head per stage.

    ``Backbone`` returns the ResNet stage outputs ``back_x['1']``, ``['2']`` and ``['3']``
    (layer2, layer3, layer4 per BackboneBase's return_layers). Each is reduced with a 1x1
    conv. A one-layer transformer decoder then attends over the spatial positions of each
    reduced map, using the map's global average as the query. Three logit vectors are fused
    with softmax-normalized learnable weights:

    * ``out_1``: ``fc`` -> ``fc_2`` on the [layer2 ++ layer3] head concatenated with the layer4 head;
    * ``out_2``, ``out_3``: the shared ``fc_3`` applied to the layer2 and layer3 heads.

    Args:
        layers: ResNet depth; ``'resnet{layers}'`` must be a key of ``resnets_dict``.
        num_classes: Number of output logits.
        reduce_dim: Width of the reduced layer4 map and of the [layer2 ++ layer3] concatenation.
            layer2 and layer3 are each reduced to ``reduce_dim // 2``.
    """

    def __init__(self, layers=50, num_classes=2, reduce_dim=256):
        super().__init__()
        # Submodules are created in the original order so state_dict keys and random
        # initialization match existing checkpoints.

        self.backbone = Backbone(
            'resnet{}'.format(layers),
            train_backbone=False,
            return_interm_layers=True,
            dilation=[False, True, True]
        )

        # NOTE(review): backbone2 (Swin) and backbone3 (ViT) are built, with pretrained
        # weights downloaded, but forward() never calls them. They are kept only so the
        # parameter/state_dict layout stays unchanged. Remove them together with a
        # checkpoint migration.
        self.backbone2 = create_model(
            'swin_base_patch4_window7_224',
            pretrained=True,
            features_only=False
        )
        self.backbone3 = create_model(
            'vit_base_patch16_224',
            pretrained=True,
            img_size=512,
            num_classes=num_classes
        )
        self.backbone3.head = nn.Identity()

        # Learned (C, 1) positional embeddings for each head's single query.
        self.embed_cat = nn.Embedding(reduce_dim, 1)
        self.embed_3 = nn.Embedding(reduce_dim, 1)
        self.embed_2 = nn.Embedding(reduce_dim // 2, 1)
        self.embed_1 = nn.Embedding(reduce_dim // 2, 1)

        # Each encoding has 2 * num_pos_feats channels, matching its transformer's d_model.
        self.pe_layer_cat = PositionEmbeddingSine(reduce_dim // 2, normalize=True)
        self.pe_layer_3 = PositionEmbeddingSine(reduce_dim // 2, normalize=True)
        self.pe_layer_2 = PositionEmbeddingSine(reduce_dim // 4, normalize=True)
        self.pe_layer_1 = PositionEmbeddingSine(reduce_dim // 4, normalize=True)

        self.transformer_cat = Transformer(
            d_model=reduce_dim,
            dropout=0.1,
            nhead=4,
            dim_feedforward=reduce_dim // 4,
            num_encoder_layers=0,
            num_decoder_layers=1,
            normalize_before=False,
            return_intermediate_dec=False,
        )
        self.transformer_3 = Transformer(
            d_model=reduce_dim,
            dropout=0.1,
            nhead=4,
            dim_feedforward=reduce_dim // 4,
            num_encoder_layers=0,
            num_decoder_layers=1,
            normalize_before=False,
            return_intermediate_dec=False,
        )
        self.transformer_2 = Transformer(
            d_model=reduce_dim // 2,
            dropout=0.1,
            nhead=4,
            dim_feedforward=reduce_dim // 4,
            num_encoder_layers=0,
            num_decoder_layers=1,
            normalize_before=False,
            return_intermediate_dec=False,
        )
        self.transformer_1 = Transformer(
            d_model=reduce_dim // 2,
            dropout=0.1,
            nhead=4,
            dim_feedforward=reduce_dim // 4,
            num_encoder_layers=0,
            num_decoder_layers=1,
            normalize_before=False,
            return_intermediate_dec=False,
        )

        # 1x1 channel reduction of the ResNet layer2 (512), layer3 (1024) and layer4 (2048) outputs.
        self.conv_red_1 = nn.Sequential(
            nn.Conv2d(512, reduce_dim // 2, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduce_dim // 2),
            nn.ReLU(inplace=True)
        )
        self.conv_red_2 = nn.Sequential(
            nn.Conv2d(1024, reduce_dim // 2, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduce_dim // 2),
            nn.ReLU(inplace=True)
        )
        self.conv_red_3 = nn.Sequential(
            nn.Conv2d(2048, reduce_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduce_dim),
            nn.ReLU(inplace=True)
        )

        # Pooling and classification layers.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(reduce_dim * 2, reduce_dim)
        self.fc_2 = nn.Linear(reduce_dim, num_classes)
        # NOTE(review): fc_3 is shared by the layer2 and layer3 heads; confirm the weight
        # sharing is intended (a separate layer would change checkpoint keys).
        self.fc_3 = nn.Linear(reduce_dim // 2, num_classes)

        # Learnable fusion logits for out_1, out_2, out_3 (softmax-normalized in forward).
        self.dynamic_weights = nn.Parameter(torch.tensor([1.0, 1.0, 1.0]))

    def _stage_head(self, transformer, pe_layer, embed, feat):
        """Run one single-query transformer head over a reduced feature map ``feat``.

        The query is ``avgpool(feat)``, its positional embedding is ``embed.weight``, and the
        keys/values are the spatial positions of ``feat`` with ``pe_layer(feat)`` as their
        positional encoding. Returns the transformer output, ``(B, C, 1)``.
        """
        query_embed = self.avgpool(feat).squeeze(-1)  # (B, C, 1, 1) -> (B, C, 1)
        key_pos = pe_layer(feat)
        return transformer(feat, None, query_embed, embed.weight, key_pos)

    def forward(self, x):
        """Return fused logits of shape ``(B, num_classes)``."""
        back_x = self.backbone(x)

        # 1x1-reduced ResNet features for layer2, layer3 and layer4.
        red_back_1 = self.conv_red_1(back_x['1'])
        red_back_2 = self.conv_red_2(back_x['2'])
        red_back_3 = self.conv_red_3(back_x['3'])
        # Channel concatenation requires layer2 and layer3 to share a spatial size, which
        # holds with dilation=[False, True, True].
        red_back_cat = torch.cat((red_back_1, red_back_2), dim=1)

        fg_embed_1 = self._stage_head(self.transformer_1, self.pe_layer_1, self.embed_1, red_back_1)
        # Bug fix: the layer3 positional encoding was computed from red_back_1. The encoding
        # depends only on the batch size and spatial size of its input, so results are
        # unchanged while the two maps share a size, but it breaks once they differ.
        fg_embed_2 = self._stage_head(self.transformer_2, self.pe_layer_2, self.embed_2, red_back_2)
        fg_embed_3 = self._stage_head(self.transformer_3, self.pe_layer_3, self.embed_3, red_back_3)
        fg_embed_cat = self._stage_head(
            self.transformer_cat, self.pe_layer_cat, self.embed_cat, red_back_cat
        )

        # (B, reduce_dim, 1) ++ (B, reduce_dim, 1) -> (B, 2 * reduce_dim) -> logits.
        out_1 = torch.flatten(torch.cat((fg_embed_cat, fg_embed_3), dim=1), 1)
        out_1 = self.fc_2(self.fc(out_1))
        out_2 = self.fc_3(torch.flatten(fg_embed_1, 1))
        out_3 = self.fc_3(torch.flatten(fg_embed_2, 1))

        dynamic_weights = torch.softmax(self.dynamic_weights, dim=0)
        return (
            dynamic_weights[0] * out_1
            + dynamic_weights[1] * out_2
            + dynamic_weights[2] * out_3
        )

    def get_backbone_params(self):
        """Return the ResNet backbone's parameters (not backbone2/backbone3)."""
        return self.backbone.parameters()

    def get_fc_params(self):
        """Return the parameters of ``fc`` only."""
        return self.fc.parameters()

## Old base++ (`MSR1`: ResNet transformer heads fused with a ViT head)

In [None]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16
from timm import create_model 


class MSR1(nn.Module):
    """ResNet transformer heads fused with a ViT classifier head.

    ``out_1`` comes from two single-query transformer heads over the reduced ResNet
    features: the [layer2 ++ layer3] concatenation and layer4. ``out_2`` comes from
    ``fc_3`` applied to the ViT (``backbone3``) features. The two are fused with
    softmax-normalized learnable weights.

    Args:
        layers: ResNet depth; ``'resnet{layers}'`` must be a key of ``resnets_dict``.
        num_classes: Number of output logits.
        reduce_dim: Width of the reduced layer4 map and of the [layer2 ++ layer3] concatenation.
    """

    def __init__(self, layers, num_classes=2, reduce_dim=256):
        super().__init__()
        # Submodules are created in the original order so state_dict keys and random
        # initialization match existing checkpoints.
        self.backbone = Backbone(
            'resnet{}'.format(layers),
            train_backbone=False,
            return_interm_layers=True,
            dilation=[False, True, True]
        )

        # ViT backbone with its classification head removed, so it returns features.
        self.backbone3 = create_model(
            'vit_base_patch16_224',  # Base ViT model
            pretrained=True,         # Use pre-trained weights
            img_size=512,            # Update input size to 512x512
            num_classes=num_classes
        )
        self.backbone3.head = nn.Identity()

        self.embed_cat = nn.Embedding(reduce_dim, 1)
        self.embed_3 = nn.Embedding(reduce_dim, 1)
        self.pe_layer_cat = PositionEmbeddingSine(reduce_dim // 2, normalize=True)
        self.pe_layer_3 = PositionEmbeddingSine(reduce_dim // 2, normalize=True)

        self.transformer_cat = Transformer(
            d_model=reduce_dim,
            dropout=0.1,
            nhead=4,
            dim_feedforward=reduce_dim // 4,
            num_encoder_layers=0,
            num_decoder_layers=1,
            normalize_before=False,
            return_intermediate_dec=False,
        )

        self.transformer_3 = Transformer(
            d_model=reduce_dim,
            dropout=0.1,
            nhead=4,
            dim_feedforward=reduce_dim // 4,
            num_encoder_layers=0,
            num_decoder_layers=1,
            normalize_before=False,
            return_intermediate_dec=False,
        )
        # 1x1 channel reduction of the ResNet layer2 (512), layer3 (1024), layer4 (2048) outputs.
        self.conv_red_1 = nn.Sequential(
            nn.Conv2d(512, reduce_dim // 2, kernel_size=1, padding=0, bias=False)
        )
        self.conv_red_2 = nn.Sequential(
            nn.Conv2d(1024, reduce_dim // 2, kernel_size=1, padding=0, bias=False)
        )
        self.conv_red_3 = nn.Sequential(
            nn.Conv2d(2048, reduce_dim, kernel_size=1, padding=0, bias=False)
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(reduce_dim * 2, reduce_dim)
        self.fc_2 = nn.Linear(reduce_dim, num_classes)
        self.fc_3 = nn.Linear(768, num_classes)  # classifies the 768-wide ViT features
        self.num_classes = num_classes

        # Learnable fusion logits for out_1 and out_2 (softmax-normalized in forward).
        self.dynamic_weights = nn.Parameter(torch.tensor([1.0, 1.0]))

    def forward(self, x):
        """Return fused logits of shape ``(B, num_classes)``.

        NOTE(review): backbone3 is a ViT built with img_size=512, so ``x`` presumably has
        to be 512x512. Confirm against the dataloader resolution.
        """
        back_x = self.backbone(x)
        # Bug fix: this call was commented out while back_x2 was still used below, so every
        # forward pass raised NameError. With its head replaced by nn.Identity, backbone3
        # presumably returns (B, 768) features, which is what fc_3 expects.
        back_x2 = self.backbone3(x)

        red_back_1 = self.conv_red_1(back_x['1'])
        red_back_2 = self.conv_red_2(back_x['2'])
        red_back_cat = torch.cat((red_back_1, red_back_2), dim=1)
        avg_back_cat = self.avgpool(red_back_cat)

        red_back_3 = self.conv_red_3(back_x['3'])
        avg_back_3 = self.avgpool(red_back_3)

        # Single-query heads: the query is the pooled map (B, C, 1), and the keys/values are
        # the map's spatial positions.
        fg_embed_cat = self.transformer_cat(
            red_back_cat,
            None,
            avg_back_cat.squeeze(-1),
            self.embed_cat.weight,
            self.pe_layer_cat(red_back_cat),
        )
        fg_embed_3 = self.transformer_3(
            red_back_3,
            None,
            avg_back_3.squeeze(-1),
            self.embed_3.weight,
            self.pe_layer_3(red_back_3),
        )

        # Concatenate the two heads along the channel dimension: (B, 2 * reduce_dim, 1).
        out = torch.cat((fg_embed_cat, fg_embed_3), dim=1)
        out_1 = self.fc_2(self.fc(torch.flatten(out, 1)))
        out_2 = self.fc_3(back_x2)

        dynamic_weights = torch.softmax(self.dynamic_weights, dim=0)
        return dynamic_weights[0] * out_1 + dynamic_weights[1] * out_2

    def get_backbone_params(self):
        """Return the ResNet backbone's parameters (not the ViT's)."""
        return self.backbone.parameters()

    def get_fc_params(self):
        """Return the parameters of ``fc`` only."""
        return self.fc.parameters()

# Dataloader

In [None]:
# Side length in pixels that the dataloaders' Resize((img_res, img_res)) uses.
# NOTE(review): MSR/MSR1 build their ViT with img_size=512, so keep this at 512 when
# using those models.
img_res = 512

In [None]:
import os
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

def get_data_loader(data_dir, batch_size):
    """Return ``(train_loader, test_loader)`` over ``<data_dir>/train`` and ``<data_dir>/test``.

    Both splits are resized to ``img_res`` x ``img_res`` and converted to
    tensors. The train split is additionally augmented with
    ``RandomRotation(15)`` and brightness/contrast ``ColorJitter``.

    Both loaders iterate in dataset order (``shuffle=False``). ``generate_masks``
    relies on this: it recovers each image's filename from its position in
    ``dataset.samples``.

    Args:
        data_dir: root directory containing ``train/`` and ``test/`` ImageFolder trees.
        batch_size: batch size for both loaders.
    """
    train_transform = transforms.Compose([
        transforms.Resize((img_res, img_res)),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor()
    ])

    # Use the shared resolution (was hard-coded to 512) so both splits always match.
    test_transform = transforms.Compose([
        transforms.Resize((img_res, img_res)),
        transforms.ToTensor()
    ])

    train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, "train"), transform=train_transform)
    test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, "test"), transform=test_transform)

    # Deliberately unshuffled: see docstring.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, test_loader

def get_test_loader(data_dir, batch_size):
    """Return an ordered DataLoader over ``<data_dir>/test``.

    Images are resized to ``img_res`` x ``img_res``, converted to tensors and
    normalised with the ImageNet mean/std.

    NOTE(review): ``get_data_loader`` and ``get_combined_data_loaders`` do not
    apply ``Normalize``; confirm this preprocessing matches the model being
    evaluated with this loader.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    pipeline = transforms.Compose([
        transforms.Resize((img_res, img_res)),
        transforms.ToTensor(),
        normalize,
    ])

    split_root = os.path.join(data_dir, "test")
    return DataLoader(
        datasets.ImageFolder(root=split_root, transform=pipeline),
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )


In [None]:
# Jupyter shell magic: install segmentation-models-pytorch into the notebook environment.
!pip install segmentation-models-pytorch

# Generate Masks

In [None]:
import pickle
import os
from PIL import Image
import numpy as np
import torch
import shutil  # For directory management
from torchvision import transforms
import matplotlib.pyplot as plt

# Load the trained segmentation model from a .pth state-dict checkpoint.
model_path = "/kaggle/input/dhur/pytorch/default/1/best_model_base.pth"
model = UNet(in_channels=3, out_channels=3, feature_dims=16)  # Match your training parameters
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime too.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)  # Move the model to the appropriate device (CPU or GPU)

print("Model loaded successfully from .pth")

# Sanity check: per-parameter statistics of the loaded weights.
for name, param in model.named_parameters():
    print(f"{name}: mean={param.mean().item()}, std={param.std().item()}")

from pathlib import Path

# Output roots for generate_masks (one per split).
for split in ("train", "test"):
    Path(f'/kaggle/working/Kermany_masks/{split}').mkdir(parents=True, exist_ok=True)



def generate_masks(model, data_loader, device, output_dir, threshold=0.5):
    """Run a segmentation model over ``data_loader`` and save one binary mask per image.

    Masks go to ``output_dir/NORMAL`` (label 0) or ``output_dir/PNEUMONIA``
    (label 1) as 8-bit images with values 0/255. Each mask keeps the source
    image's basename, except that ``.jpg`` is rewritten to ``_mask.png``.
    The output format follows that filename's extension. For example, a
    ``.jpeg`` source is written as JPEG.

    The source filename of each prediction is recovered from its position in
    ``data_loader.dataset.samples``. The loader must therefore iterate its
    dataset in order (``shuffle=False``), and the dataset must expose
    ``samples`` (e.g. ``torchvision.datasets.ImageFolder``).

    Args:
        model: segmentation network; its outputs are compared against ``threshold``.
        data_loader: ordered loader yielding ``(inputs, targets)`` batches.
        device: device on which to run the model.
        output_dir: root directory for the saved masks.
        threshold: model outputs strictly above this value count as foreground.

    Raises:
        ValueError: if a label other than 0 or 1 is encountered.
    """
    class_dirs = {
        0: os.path.join(output_dir, "NORMAL"),
        1: os.path.join(output_dir, "PNEUMONIA"),
    }
    for class_dir in class_dirs.values():
        os.makedirs(class_dir, exist_ok=True)

    samples = data_loader.dataset.samples
    model.eval()
    model.to(device)

    offset = 0  # index in `samples` of the first item of the current batch
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs.to(device))
            # Binarize exactly once. Re-thresholding the resulting {0, 1} values
            # against `threshold` would wipe or saturate the mask whenever
            # threshold is outside [0, 1).
            masks = (outputs > threshold).cpu().numpy().astype(np.uint8)

            for j, label in enumerate(targets.tolist()):
                if label not in class_dirs:
                    raise ValueError(f"Unexpected label value: {label}")

                original_path, _ = samples[offset + j]
                image_name = os.path.basename(original_path)
                save_path = os.path.join(class_dirs[label], image_name.replace(".jpg", "_mask.png"))

                pred_mask = masks[j].squeeze()
                if pred_mask.ndim == 3:
                    pred_mask = pred_mask[0]  # multi-channel output: keep the first channel

                Image.fromarray((pred_mask * 255).astype(np.uint8)).save(save_path)

            offset += inputs.size(0)



# Example Usage
# data_dir = "./kermany/kermany"  # Ensure this is where your dataset is located
data_dir = "/kaggle/working/pediatric_pneumonia/Pediatric Chest X-ray Pneumonia"
batch_size = 8
# One mask directory per split.
output_dir = "/kaggle/working/Kermany_masks/test"
output_dir1 = "/kaggle/working/Kermany_masks/train"

# `model` is the segmentation network loaded above.
train_loader, test_loader = get_data_loader(data_dir, batch_size)

# get_data_loader's train loader applies RandomRotation/ColorJitter, so masks
# predicted from it would come from randomly rotated images. Predict train masks
# from an ordered, un-augmented loader that uses the test split's preprocessing.
mask_train_dataset = datasets.ImageFolder(
    root=os.path.join(data_dir, "train"), transform=test_loader.dataset.transform
)
mask_train_loader = DataLoader(mask_train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

generate_masks(model, test_loader, device, output_dir)
print("Test.")
generate_masks(model, mask_train_loader, device, output_dir1)
print("Train")

print("Output generated.")

# Data Preprocessing 

In [None]:
from torch.utils.data import random_split

def get_combined_data_loaders(data_dir, mask_dir, batch_size, seed=None):
    """Build train/validation/test loaders that yield ``(mask, label)`` pairs.

    Labels come from ``torchvision.datasets.ImageFolder`` over
    ``<data_dir>/train`` and ``<data_dir>/test``. For each image, the matching
    mask is read from ``<mask_dir>/<split>/<class_name>/``. The mask is looked up
    first under the image's exact basename, then under the name
    ``generate_masks`` writes for ``.jpg`` inputs (``.jpg`` -> ``_mask.png``).
    Masks are resized to ``img_res`` x ``img_res``, converted to tensors and
    repeated to 3 channels.

    NOTE(review): only the mask is returned as model input; the X-ray image
    itself is not used. Confirm that this mask-only setup is intended.

    Args:
        data_dir: root containing ``train/`` and ``test/`` image folders.
        mask_dir: root containing ``train/`` and ``test/`` mask folders.
        batch_size: batch size for all three loaders.
        seed: optional seed for the 80/20 train/validation split. ``None`` uses
            torch's global RNG, so the split differs between runs.

    Returns:
        ``(train_loader, validation_loader, test_loader)``; only the train
        loader shuffles.

    Raises:
        FileNotFoundError: when a sample is fetched and no mask exists for its image.
    """
    mask_transform = transforms.Compose([
        transforms.Resize((img_res, img_res)),
        transforms.ToTensor()
    ])

    class CombinedDataset(torch.utils.data.Dataset):
        """Pairs each ImageFolder sample's label with its precomputed mask."""

        def __init__(self, data_dir, mask_dir, transform_mask):
            # ImageFolder is used only for its (path, label) index.
            self.image_dataset = datasets.ImageFolder(root=data_dir, transform=None)
            self.mask_dir = mask_dir
            self.transform_mask = transform_mask
            self.samples = self.image_dataset.samples

        def __len__(self):
            return len(self.image_dataset)

        def _mask_path(self, img_path):
            """Return the existing mask file for ``img_path`` or raise FileNotFoundError."""
            class_name = os.path.basename(os.path.dirname(img_path))
            image_name = os.path.basename(img_path)
            class_mask_dir = os.path.join(self.mask_dir, class_name)
            for candidate in (image_name, image_name.replace(".jpg", "_mask.png")):
                mask_path = os.path.join(class_mask_dir, candidate)
                if os.path.exists(mask_path):
                    return mask_path
            raise FileNotFoundError(f"Mask not found: {os.path.join(class_mask_dir, image_name)}")

        def __getitem__(self, index):
            img_path, label = self.samples[index]
            with Image.open(self._mask_path(img_path)) as raw_mask:
                mask = raw_mask.convert('L')  # Grayscale for masks
            if self.transform_mask:
                mask = self.transform_mask(mask)
            if mask.shape[0] == 1:
                mask = mask.repeat(3, 1, 1)
            return mask, label

    train_dataset = CombinedDataset(
        os.path.join(data_dir, "train"), os.path.join(mask_dir, "train"), mask_transform
    )
    test_dataset = CombinedDataset(
        os.path.join(data_dir, "test"), os.path.join(mask_dir, "test"), mask_transform
    )

    # 80/20 train/validation split of the training folder.
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_subset, val_subset = random_split(train_dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
    validation_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, validation_loader, test_loader

# Example usage: build the mask-based train/validation/test loaders (batch size 8).
# The training cell below rebuilds these loaders with batch_size = 64.
data_dir = "/kaggle/working/pediatric_pneumonia/Pediatric Chest X-ray Pneumonia"  # Path to images
mask_dir = "/kaggle/working/Kermany_masks"   # Path to masks
batch_size = 8

train_loader, validation_loader, test_loader = get_combined_data_loaders(data_dir, mask_dir, batch_size)

# Train 

In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Best validation scores seen so far. The epoch loop later in this cell updates
# best_val_f1 when it saves a checkpoint; best_val_acc is only initialised.
best_val_acc, best_val_f1 = 0.0, 0.0
def count_parameters(model, trainable_only=False):
    """Return the total number of scalar elements in ``model``'s parameters.

    Args:
        model: an ``nn.Module``.
        trainable_only: when true, count only parameters with ``requires_grad``.
    """
    params = model.parameters()
    if trainable_only:
        params = filter(lambda p: p.requires_grad, params)
    return sum(map(torch.Tensor.numel, params))

def train(model, train_loader, optimizer, criterion, device):
    """Run one training epoch over ``train_loader``.

    Each batch is moved to ``device``, scored, and used for one optimizer step.
    Predictions are the argmax of the logits computed before that step.

    Returns:
        ``(mean_batch_loss, accuracy_percent, predictions, targets)``. The loss
        is the mean of per-batch ``criterion`` values; the two lists hold
        per-sample predicted and true labels (numpy scalars) in loader order.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    preds_out, targets_out = [], []

    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        batch_pred = logits.max(dim=1).indices
        n_seen += batch_targets.size(0)
        n_correct += (batch_pred == batch_targets).sum().item()

        preds_out.extend(batch_pred.cpu().numpy())
        targets_out.extend(batch_targets.cpu().numpy())

    return loss_sum / len(train_loader), 100. * n_correct / n_seen, preds_out, targets_out


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    Returns:
        ``(mean_batch_loss, accuracy_percent, predictions, targets)``. The loss
        is the mean of per-batch ``criterion`` values; the two lists hold
        per-sample predicted and true labels (numpy scalars) in loader order.
    """
    model.eval()
    loss_accum, hit_count, sample_count = 0.0, 0, 0
    predicted_labels, true_labels = [], []

    with torch.no_grad():
        for features, labels in val_loader:
            features = features.to(device)
            labels = labels.to(device)

            logits = model(features)
            loss_accum += criterion(logits, labels).item()

            guesses = torch.max(logits, 1)[1]
            sample_count += labels.size(0)
            hit_count += guesses.eq(labels).sum().item()

            predicted_labels += list(guesses.cpu().numpy())
            true_labels += list(labels.cpu().numpy())

    mean_loss = loss_accum / len(val_loader)
    accuracy = 100. * hit_count / sample_count
    return mean_loss, accuracy, predicted_labels, true_labels

# Dataset location and loader parameters.
data_dir = "/kaggle/working/pediatric_pneumonia/Pediatric Chest X-ray Pneumonia"  # Path to images
mask_dir = "/kaggle/working/Kermany_masks"   # Path to masks
batch_size = 64

train_loader, val_loader, test_loader = get_combined_data_loaders(data_dir, mask_dir, batch_size)

learning_rate = 5e-3
num_epochs = 5
log_dir = "./logs"

# Instantiate the MSR1 model for 2 classes (NORMAL, PNEUMONIA).
model = MSR1(layers=50, num_classes=2)

# Count parameters (whole model, and the subset with requires_grad).
total_params = count_parameters(model)
trainable_params = count_parameters(model, trainable_only=True)
print(f"\nTotal # param.: {total_params}")
print(f"Learnable # param.: {trainable_params}\n")

# Setup device and move model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()

# get_fc_params() covers only `model.fc`. Every other non-backbone parameter
# (channel reducers, transformers, query embeddings, fc_2, fc_3, ...) must be
# given to the optimizer explicitly, or it would never be updated.
backbone_params = list(model.get_backbone_params())
fc_params = list(model.get_fc_params())
grouped_ids = {id(p) for p in backbone_params + fc_params}
grouped_ids.add(id(model.dynamic_weights))
other_params = [p for p in model.parameters() if id(p) not in grouped_ids]

param_groups = [
    {'params': backbone_params, 'lr': learning_rate},
    {'params': fc_params, 'lr': learning_rate * 10},
    {'params': [model.dynamic_weights], 'lr': 0.01},
]
if other_params:
    # NOTE(review): base learning rate chosen for these new layers -- tune if needed.
    param_groups.append({'params': other_params, 'lr': learning_rate})
optimizer = optim.Adam(param_groups)

# Train and validate the model
best_val_acc = 0.0
os.makedirs(log_dir, exist_ok=True)

from sklearn.metrics import precision_score, recall_score, f1_score

def calculate_metrics(y_true, y_pred, average='weighted'):
    """Return ``(precision, recall, f1)`` for the given true/predicted labels.

    Thin wrapper over scikit-learn's ``precision_score``, ``recall_score`` and
    ``f1_score``, all called with the same ``average`` mode.
    """
    return tuple(
        scorer(y_true, y_pred, average=average)
        for scorer in (precision_score, recall_score, f1_score)
    )

# Epoch loop: one training pass and one validation pass per epoch, with
# weighted precision/recall/F1 (calculate_metrics' default `average`). The
# checkpoint with the highest validation F1 is written to
# <log_dir>/best_model.pth.
# NOTE(review): the comparison below is strict and best_val_f1 starts at 0.0,
# so if validation F1 never exceeds 0.0 no checkpoint is written.
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    start_time = time.time()

    # Training Phase
    train_loss, train_acc, train_preds, train_targets = train(model, train_loader, optimizer, criterion, device)
    train_precision, train_recall, train_f1 = calculate_metrics(train_targets, train_preds)

    # Validation Phase
    val_loss, val_acc, val_preds, val_targets = validate(model, val_loader, criterion, device)
    val_precision, val_recall, val_f1 = calculate_metrics(val_targets, val_preds)

    # Wall-clock time for train + validation of this epoch.
    end_time = time.time()
    epoch_time = end_time - start_time

    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
    print(f"Train Precision: {train_precision:.2f}, Train Recall: {train_recall:.2f}, Train F1-Score: {train_f1:.2f}")

    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
    print(f"Validation Precision: {val_precision:.2f}, Validation Recall: {val_recall:.2f}, Validation F1-Score: {val_f1:.2f}")

    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds\n")

    # Save the best model based on F1-Score
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), os.path.join(log_dir, "best_model.pth"))
        print(f"Best model saved with F1-Score: {best_val_f1:.2f}")


In [None]:
def test(model, test_loader, criterion, device):
    """Evaluate ``model`` on the test split in eval mode without gradients.

    Args:
        model: trained model to evaluate.
        test_loader: DataLoader yielding ``(inputs, targets)`` batches.
        criterion: loss function applied to each batch.
        device: device on which the computation runs (CPU or GPU).

    Returns:
        ``(test_loss, test_acc, all_preds, all_targets)``. ``test_loss`` is the
        mean of per-batch losses and ``test_acc`` is a percentage. The lists
        hold per-sample predicted and true labels in loader order.
        Loss and accuracy are also printed.
    """
    model.eval()
    loss_total = 0.0
    hits = 0
    seen = 0
    predictions, labels = [], []

    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            scores = model(batch_x)
            loss_total += criterion(scores, batch_y).item()

            batch_pred = scores.max(dim=1).indices
            seen += batch_y.size(0)
            hits += (batch_pred == batch_y).sum().item()

            predictions.extend(batch_pred.cpu().numpy())
            labels.extend(batch_y.cpu().numpy())

    test_loss = loss_total / len(test_loader)
    test_acc = 100. * hits / seen

    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

    return test_loss, test_acc, predictions, labels

In [None]:
import gc
# Collect unreferenced Python objects before evaluation.
gc.collect()

# Release cached, currently unused GPU memory held by PyTorch's CUDA allocator.
torch.cuda.empty_cache()

In [None]:
# Reload the checkpoint written by the training loop (highest validation F1)
# into the in-memory `model`, then evaluate it on `test_loader`.
# Assuming you have already loaded the model and test_loader
# model_path = "/kaggle/input/ugh/pytorch/default/1/best_model.pth"
# model = MSR(layers=50, num_classes=2)  # Adjust for 3 classes: COVID19, NORMAL, PNEUMONIA
model_path = os.path.join(log_dir, "best_model.pth")
model.load_state_dict(torch.load(model_path))
model.to(device)  # Move the model to the appropriate device (CPU or GPU)
#criterion = nn.CrossEntropyLoss()

test_loss, test_acc, all_preds, all_targets = test(model, test_loader, criterion, device)

# Print results (test() has already printed the same loss/accuracy once)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")

# If needed, you can use all_preds and all_targets for further analysis
