In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, trunc_normal_

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class ConvBlock(nn.Module):

    def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
                 norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
        super(ConvBlock, self).__init__()

        expansion = 4
        med_planes = outplanes // expansion

        self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = norm_layer(med_planes)
        self.act1 = act_layer(inplace=True)

        self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=stride, groups=groups, padding=1, bias=False)
        self.bn2 = norm_layer(med_planes)
        self.act2 = act_layer(inplace=True)

        self.conv3 = nn.Conv2d(med_planes, outplanes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = norm_layer(outplanes)
        self.act3 = act_layer(inplace=True)

        if res_conv:
            self.residual_conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, padding=0, bias=False)
            self.residual_bn = norm_layer(outplanes)

        self.res_conv = res_conv
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, x_t=None, return_x_2=True):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        x = self.conv2(x) if x_t is None else self.conv2(x + x_t)
        x = self.bn2(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x2 = self.act2(x)

        x = self.conv3(x2)
        x = self.bn3(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.res_conv:
            residual = self.residual_conv(residual)
            residual = self.residual_bn(residual)

        x += residual
        x = self.act3(x)

        if return_x_2:
            return x, x2
        else:
            return x


class FCUDown(nn.Module):
    """ CNN feature maps -> Transformer patch embeddings
    """

    def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super(FCUDown, self).__init__()
        self.dw_stride = dw_stride

        self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
        self.sample_pooling = nn.AvgPool2d(kernel_size=dw_stride, stride=dw_stride)

        self.ln = norm_layer(outplanes)
        self.act = act_layer()

    def forward(self, x, x_t):
        x = self.conv_project(x)  # [N, C, H, W]

        x = self.sample_pooling(x).flatten(2).transpose(1, 2)
        x = self.ln(x)
        x = self.act(x)

        x = torch.cat([x_t[:, 0][:, None, :], x], dim=1)

        return x


class FCUUp(nn.Module):
    """ Transformer patch embeddings -> CNN feature maps
    """

    def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
                 norm_layer=partial(nn.BatchNorm2d, eps=1e-6),):
        super(FCUUp, self).__init__()

        self.up_stride = up_stride
        self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)
        self.bn = norm_layer(outplanes)
        self.act = act_layer()

    def forward(self, x, H, W):
        B, _, C = x.shape
        # [N, 197, 384] -> [N, 196, 384] -> [N, 384, 196] -> [N, 384, 14, 14]
        x_r = x[:, 1:].transpose(1, 2).reshape(B, C, H, W)
        x_r = self.act(self.bn(self.conv_project(x_r)))

        return F.interpolate(x_r, size=(H * self.up_stride, W * self.up_stride))


class Med_ConvBlock(nn.Module):
    """ special case for Convblock with down sampling,
    """
    def __init__(self, inplanes, act_layer=nn.ReLU, groups=1, norm_layer=partial(nn.BatchNorm2d, eps=1e-6),
                 drop_block=None, drop_path=None):

        super(Med_ConvBlock, self).__init__()

        expansion = 4
        med_planes = inplanes // expansion

        self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = norm_layer(med_planes)
        self.act1 = act_layer(inplace=True)

        self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=1, groups=groups, padding=1, bias=False)
        self.bn2 = norm_layer(med_planes)
        self.act2 = act_layer(inplace=True)

        self.conv3 = nn.Conv2d(med_planes, inplanes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = norm_layer(inplanes)
        self.act3 = act_layer(inplace=True)

        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x += residual
        x = self.act3(x)

        return x


class ConvTransBlock(nn.Module):
    """
    Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block
    """

    def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 last_fusion=False, num_med_block=0, groups=1):

        super(ConvTransBlock, self).__init__()
        expansion = 4
        self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=stride, groups=groups)

        if last_fusion:
            self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, stride=2, res_conv=True, groups=groups)
        else:
            self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, groups=groups)

        if num_med_block > 0:
            self.med_block = []
            for i in range(num_med_block):
                self.med_block.append(Med_ConvBlock(inplanes=outplanes, groups=groups))
            self.med_block = nn.ModuleList(self.med_block)

        self.squeeze_block = FCUDown(inplanes=outplanes // expansion, outplanes=embed_dim, dw_stride=dw_stride)

        self.expand_block = FCUUp(inplanes=embed_dim, outplanes=outplanes // expansion, up_stride=dw_stride)

        self.trans_block = Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate)

        self.dw_stride = dw_stride
        self.embed_dim = embed_dim
        self.num_med_block = num_med_block
        self.last_fusion = last_fusion

    def forward(self, x, x_t):
        x, x2 = self.cnn_block(x)

        _, _, H, W = x2.shape

        x_st = self.squeeze_block(x2, x_t)

        x_t = self.trans_block(x_st + x_t)

        if self.num_med_block > 0:
            for m in self.med_block:
                x = m(x)

        x_t_r = self.expand_block(x_t, H // self.dw_stride, W // self.dw_stride)
        x = self.fusion_block(x, x_t_r, return_x_2=False)

        return x, x_t


class Conformer(nn.Module):

    def __init__(self, patch_size=16, in_chans=3, num_classes=1000, base_channel=64, channel_ratio=4, num_med_block=0,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):

        # Transformer
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        assert depth % 3 == 0

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Classifier head
        self.trans_norm = nn.LayerNorm(embed_dim)
        self.trans_cls_head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)

        # Stem stage: get the feature maps by conv block (copied form resnet.py)
        self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

        # 1 stage
        stage_1_channel = int(base_channel * channel_ratio)
        trans_dw_stride = patch_size // 4
        self.conv_1 = ConvBlock(inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1)
        self.trans_patch_conv = nn.Conv2d(64, embed_dim, kernel_size=trans_dw_stride, stride=trans_dw_stride, padding=0)
        self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                             qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0],
                             )

        # 2~4 stage
        init_stage = 2
        fin_stage = depth // 3 + 1
        for i in range(init_stage, fin_stage):
            self.add_module('conv_trans_' + str(i),
                    ConvTransBlock(
                        stage_1_channel, stage_1_channel, False, 1, dw_stride=trans_dw_stride, embed_dim=embed_dim,
                        num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i-1],
                        num_med_block=num_med_block
                    )
            )


        stage_2_channel = int(base_channel * channel_ratio * 2)
        # 5~8 stage
        init_stage = fin_stage # 5
        fin_stage = fin_stage + depth // 3 # 9
        for i in range(init_stage, fin_stage):
            s = 2 if i == init_stage else 1
            in_channel = stage_1_channel if i == init_stage else stage_2_channel
            res_conv = True if i == init_stage else False
            self.add_module('conv_trans_' + str(i),
                    ConvTransBlock(
                        in_channel, stage_2_channel, res_conv, s, dw_stride=trans_dw_stride // 2, embed_dim=embed_dim,
                        num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i-1],
                        num_med_block=num_med_block
                    )
            )

        stage_3_channel = int(base_channel * channel_ratio * 2 * 2)
        # 9~12 stage
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + depth // 3  # 13
        for i in range(init_stage, fin_stage):
            s = 2 if i == init_stage else 1
            in_channel = stage_2_channel if i == init_stage else stage_3_channel
            res_conv = True if i == init_stage else False
            last_fusion = True if i == depth else False
            self.add_module('conv_trans_' + str(i),
                    ConvTransBlock(
                        in_channel, stage_3_channel, res_conv, s, dw_stride=trans_dw_stride // 4, embed_dim=embed_dim,
                        num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i-1],
                        num_med_block=num_med_block, last_fusion=last_fusion
                    )
            )
        self.fin_stage = fin_stage

        trunc_normal_(self.cls_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.weight, 1.)
            nn.init.constant_(m.bias, 0.)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token'}


    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        # pdb.set_trace()
        # stem stage [N, 3, 224, 224] -> [N, 64, 56, 56]
        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))

        # 1 stage
        x = self.conv_1(x_base, return_x_2=False)

        x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
        x_t = torch.cat([cls_tokens, x_t], dim=1)
        x_t = self.trans_1(x_t)

        # 2 ~ final
        for i in range(2, self.fin_stage):
            x, x_t = eval('self.conv_trans_' + str(i))(x, x_t)

        # conv classification
        x_p = self.pooling(x).flatten(1)
        conv_cls = self.conv_cls_head(x_p)

        # trans classification
        x_t = self.trans_norm(x_t)
        tran_cls = self.trans_cls_head(x_t[:, 0])

        return [conv_cls, tran_cls]



In [3]:
!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.0.tar.gz (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.7.0-py3-none-any.whl size=114250 sha256=a21739ba37dc0d2434a9f0ac64fcd0cdbb3f87b7cf581c657529872b36451e2c
  Stored in directory: /root/.cache/pip/wheels/19/39/2f/2d3cadc408a8804103f1c34ddd4b9f6a93497b11fa96fe738e
Successfully built fire
Installing collected packages: fire, medmnist
Successfully installed fire-0.7.0 medmnist-3.0.2


In [4]:
from medmnist import BloodMNIST
from torch.utils.data import DataLoader
from torchvision import transforms
# Define transforms
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(
            degrees=0,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
            shear=10
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )  # Conformer expects 224x224 input
])

# Use the same transforms for validation and test
ot_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
              mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225]
          )
])

# Load datasets
train_dataset = BloodMNIST(split="train", download=True, transform=train_transforms, size=224)
val_dataset = BloodMNIST(split="val", transform=ot_transforms, download=True, size=224)
test_dataset = BloodMNIST(split="test", transform=ot_transforms, download=True, size=224)

# Create data loaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

image, label = train_dataset[0]

# Check the image shape
print(f"Image shape: {image.shape}, Label: {label}")

Downloading https://zenodo.org/records/10519652/files/bloodmnist_224.npz?download=1 to /root/.medmnist/bloodmnist_224.npz


100%|██████████| 1540731655/1540731655 [01:22<00:00, 18762452.58it/s]


Using downloaded and verified file: /root/.medmnist/bloodmnist_224.npz
Using downloaded and verified file: /root/.medmnist/bloodmnist_224.npz
Image shape: torch.Size([3, 224, 224]), Label: [7]


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim


import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Set random seed for reproducibility
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model
model = Conformer(
    patch_size=16,
    in_chans=3,
    num_classes=8,
    embed_dim=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4,
    qkv_bias=True,
    num_med_block=2
)
model = model.to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',          # Since we're monitoring validation accuracy
    factor=0.5,         # Multiply LR by this factor when plateauing
    patience=2,         # Number of epochs to wait before reducing LR
    verbose=True,       # Print message when LR is reduced
    min_lr=1e-6        # Lower bound on the learning rate
)

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc='Training')

    for inputs, targets in pbar:
        # inputs = inputs.repeat(1, 3, 1, 1)  # Repeat grayscale image to 3 channels
        inputs, targets = inputs.to(device), targets.squeeze().to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Conformer returns [conv_cls, tran_cls]
        loss = criterion(outputs[0], targets) + criterion(outputs[1], targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({'loss': loss.item()})

    return running_loss / len(train_loader)

def evaluate(model, data_loader, device):
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for inputs, targets in data_loader:
            # inputs = inputs.repeat(1, 3, 1, 1)
            inputs, targets = inputs.to(device), targets.squeeze().to(device)

            outputs = model(inputs)
            # Average the predictions from both heads
            probs = (torch.softmax(outputs[0], dim=1) + torch.softmax(outputs[1], dim=1)) / 2
            preds = torch.argmax(probs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    return np.array(all_preds), np.array(all_targets), np.array(all_probs)

def plot_metrics(all_preds, all_targets, all_probs, save_prefix='test'):
    # Calculate metrics
    accuracy = accuracy_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds, average='weighted')
    recall = recall_score(all_targets, all_preds, average='weighted')
    f1 = f1_score(all_targets, all_preds, average='weighted')

    # Calculate ROC curve and AUC (for multi-class)
    n_classes = all_probs.shape[1]
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve((all_targets == i).astype(int), all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Plot ROC curve
    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], label=f'Class {i} (AUC = {roc_auc[i]:.2f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    plt.savefig(f'{save_prefix}_roc_curve.png')
    plt.close()

    # Plot confusion matrix
    cm = confusion_matrix(all_targets, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(f'{save_prefix}_confusion_matrix.png')
    plt.close()

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': roc_auc
    }

# Training loop
num_epochs = 100
best_val_acc = 0
patience = 10
patience_counter = 0

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')

    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_preds, val_targets, val_probs = evaluate(model, val_loader, device)
    val_metrics = plot_metrics(val_preds, val_targets, val_probs, save_prefix=f'val_epoch_{epoch}')

    print(f'Train Loss: {train_loss:.4f}')
    print(f'Validation Metrics:')
    print(f'Accuracy: {val_metrics["accuracy"]:.4f}')
    print(f'Precision: {val_metrics["precision"]:.4f}')
    print(f'Recall: {val_metrics["recall"]:.4f}')
    print(f'F1 Score: {val_metrics["f1"]:.4f}')

    scheduler.step(val_metrics['accuracy'])

    # Early stopping
    if val_metrics['accuracy'] > best_val_acc:
        best_val_acc = val_metrics['accuracy']
        torch.save(model.state_dict(), 'best_model.pth')
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print('Early stopping triggered')
            break

# Load best model and evaluate on test set
model.load_state_dict(torch.load('best_model.pth'))
test_preds, test_targets, test_probs = evaluate(model, test_loader, device)
test_metrics = plot_metrics(test_preds, test_targets, test_probs, save_prefix='test_final')

print('\nTest Metrics:')
print(f'Accuracy: {test_metrics["accuracy"]:.4f}')
print(f'Precision: {test_metrics["precision"]:.4f}')
print(f'Recall: {test_metrics["recall"]:.4f}')
print(f'F1 Score: {test_metrics["f1"]:.4f}')
print('AUC scores:', {f'Class {k}': v for k, v in test_metrics['auc'].items()})



Number of parameters: 47875088

Epoch 1/100


Training: 100%|██████████| 748/748 [08:25<00:00,  1.48it/s, loss=0.896]
  _warn_prf(average, modifier, msg_start, len(result))


Train Loss: 2.3817
Validation Metrics:
Accuracy: 0.6565
Precision: 0.6727
Recall: 0.6565
F1 Score: 0.6283

Epoch 2/100


Training: 100%|██████████| 748/748 [08:31<00:00,  1.46it/s, loss=2.63] 


Train Loss: 1.5218
Validation Metrics:
Accuracy: 0.6758
Precision: 0.8077
Recall: 0.6758
F1 Score: 0.6459

Epoch 3/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.568]


Train Loss: 1.1909
Validation Metrics:
Accuracy: 0.8715
Precision: 0.8873
Recall: 0.8715
F1 Score: 0.8681

Epoch 4/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.716]


Train Loss: 1.0554
Validation Metrics:
Accuracy: 0.8680
Precision: 0.8754
Recall: 0.8680
F1 Score: 0.8645

Epoch 5/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.157]


Train Loss: 0.8854
Validation Metrics:
Accuracy: 0.7862
Precision: 0.8465
Recall: 0.7862
F1 Score: 0.7853

Epoch 6/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=2.36]  


Train Loss: 0.7213
Validation Metrics:
Accuracy: 0.9398
Precision: 0.9404
Recall: 0.9398
F1 Score: 0.9395

Epoch 7/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=1.32]  


Train Loss: 0.6854
Validation Metrics:
Accuracy: 0.9060
Precision: 0.9179
Recall: 0.9060
F1 Score: 0.9029

Epoch 8/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.999] 


Train Loss: 0.6309
Validation Metrics:
Accuracy: 0.9498
Precision: 0.9514
Recall: 0.9498
F1 Score: 0.9499

Epoch 9/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=1]     


Train Loss: 0.6285
Validation Metrics:
Accuracy: 0.9334
Precision: 0.9365
Recall: 0.9334
F1 Score: 0.9329

Epoch 10/100


Training: 100%|██████████| 748/748 [08:27<00:00,  1.47it/s, loss=0.447] 


Train Loss: 0.5582
Validation Metrics:
Accuracy: 0.9609
Precision: 0.9622
Recall: 0.9609
F1 Score: 0.9606

Epoch 11/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=0.371] 


Train Loss: 0.7894
Validation Metrics:
Accuracy: 0.9007
Precision: 0.9145
Recall: 0.9007
F1 Score: 0.9004

Epoch 12/100


Training: 100%|██████████| 748/748 [08:27<00:00,  1.47it/s, loss=0.407] 


Train Loss: 0.7883
Validation Metrics:
Accuracy: 0.9317
Precision: 0.9360
Recall: 0.9317
F1 Score: 0.9302

Epoch 13/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=0.745] 


Train Loss: 0.4948
Validation Metrics:
Accuracy: 0.9574
Precision: 0.9583
Recall: 0.9574
F1 Score: 0.9573

Epoch 14/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=0.694]  


Train Loss: 0.3425
Validation Metrics:
Accuracy: 0.9650
Precision: 0.9656
Recall: 0.9650
F1 Score: 0.9645

Epoch 15/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=0.11]   


Train Loss: 0.3264
Validation Metrics:
Accuracy: 0.9486
Precision: 0.9504
Recall: 0.9486
F1 Score: 0.9483

Epoch 16/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.357] 


Train Loss: 0.3210
Validation Metrics:
Accuracy: 0.9591
Precision: 0.9601
Recall: 0.9591
F1 Score: 0.9586

Epoch 17/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=0.0525]


Train Loss: 0.3093
Validation Metrics:
Accuracy: 0.9614
Precision: 0.9636
Recall: 0.9614
F1 Score: 0.9615

Epoch 18/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=0.119]  


Train Loss: 0.2446
Validation Metrics:
Accuracy: 0.9626
Precision: 0.9635
Recall: 0.9626
F1 Score: 0.9623

Epoch 19/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.0705] 


Train Loss: 0.2311
Validation Metrics:
Accuracy: 0.9562
Precision: 0.9576
Recall: 0.9562
F1 Score: 0.9558

Epoch 20/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.749]  


Train Loss: 0.2185
Validation Metrics:
Accuracy: 0.9749
Precision: 0.9755
Recall: 0.9749
F1 Score: 0.9748

Epoch 21/100


Training: 100%|██████████| 748/748 [08:28<00:00,  1.47it/s, loss=0.208]  


Train Loss: 0.2192
Validation Metrics:
Accuracy: 0.9720
Precision: 0.9723
Recall: 0.9720
F1 Score: 0.9718

Epoch 22/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.376]  


Train Loss: 0.2147
Validation Metrics:
Accuracy: 0.9749
Precision: 0.9750
Recall: 0.9749
F1 Score: 0.9748

Epoch 23/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.0123] 


Train Loss: 0.2006
Validation Metrics:
Accuracy: 0.9749
Precision: 0.9753
Recall: 0.9749
F1 Score: 0.9748

Epoch 24/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.148]  


Train Loss: 0.1765
Validation Metrics:
Accuracy: 0.9708
Precision: 0.9716
Recall: 0.9708
F1 Score: 0.9706

Epoch 25/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.727]  


Train Loss: 0.1616
Validation Metrics:
Accuracy: 0.9667
Precision: 0.9685
Recall: 0.9667
F1 Score: 0.9666

Epoch 26/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.404]  


Train Loss: 0.1545
Validation Metrics:
Accuracy: 0.9761
Precision: 0.9767
Recall: 0.9761
F1 Score: 0.9760

Epoch 27/100


Training: 100%|██████████| 748/748 [08:31<00:00,  1.46it/s, loss=0.0134] 


Train Loss: 0.1508
Validation Metrics:
Accuracy: 0.9825
Precision: 0.9828
Recall: 0.9825
F1 Score: 0.9824

Epoch 28/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.00727]


Train Loss: 0.1549
Validation Metrics:
Accuracy: 0.9778
Precision: 0.9781
Recall: 0.9778
F1 Score: 0.9777

Epoch 29/100


Training: 100%|██████████| 748/748 [08:30<00:00,  1.47it/s, loss=0.0219] 


Train Loss: 0.1390
Validation Metrics:
Accuracy: 0.9801
Precision: 0.9805
Recall: 0.9801
F1 Score: 0.9800

Epoch 30/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.765]  


Train Loss: 0.1452
Validation Metrics:
Accuracy: 0.9702
Precision: 0.9713
Recall: 0.9702
F1 Score: 0.9700

Epoch 31/100


Training: 100%|██████████| 748/748 [08:30<00:00,  1.47it/s, loss=0.126]  


Train Loss: 0.1322
Validation Metrics:
Accuracy: 0.9720
Precision: 0.9728
Recall: 0.9720
F1 Score: 0.9719

Epoch 32/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.00382]


Train Loss: 0.1307
Validation Metrics:
Accuracy: 0.9784
Precision: 0.9788
Recall: 0.9784
F1 Score: 0.9782

Epoch 33/100


Training: 100%|██████████| 748/748 [08:32<00:00,  1.46it/s, loss=0.144]  


Train Loss: 0.1173
Validation Metrics:
Accuracy: 0.9731
Precision: 0.9742
Recall: 0.9731
F1 Score: 0.9731

Epoch 34/100


Training: 100%|██████████| 748/748 [08:30<00:00,  1.47it/s, loss=0.00286]


Train Loss: 0.1047
Validation Metrics:
Accuracy: 0.9755
Precision: 0.9759
Recall: 0.9755
F1 Score: 0.9753

Epoch 35/100


Training: 100%|██████████| 748/748 [08:29<00:00,  1.47it/s, loss=0.0095] 


Train Loss: 0.1076
Validation Metrics:
Accuracy: 0.9725
Precision: 0.9737
Recall: 0.9725
F1 Score: 0.9724

Epoch 36/100


Training: 100%|██████████| 748/748 [08:30<00:00,  1.47it/s, loss=0.0589] 


Train Loss: 0.1066
Validation Metrics:
Accuracy: 0.9737
Precision: 0.9744
Recall: 0.9737
F1 Score: 0.9736

Epoch 37/100


Training: 100%|██████████| 748/748 [08:31<00:00,  1.46it/s, loss=0.0257] 


Train Loss: 0.1018
Validation Metrics:
Accuracy: 0.9807
Precision: 0.9810
Recall: 0.9807
F1 Score: 0.9806
Early stopping triggered


  model.load_state_dict(torch.load('best_model.pth'))



Test Metrics:
Accuracy: 0.9798
Precision: 0.9800
Recall: 0.9798
F1 Score: 0.9798
AUC scores: {'Class 0': 0.9998451988420873, 'Class 1': 0.9997255530192606, 'Class 2': 0.9997756433452921, 'Class 3': 0.9980516773441555, 'Class 4': 0.9992101044475005, 'Class 5': 0.9986799983836714, 'Class 6': 0.9980728459857316, 'Class 7': 0.9999992790038718}
