In [None]:
import warnings
import timm
import gc

from fastai.vision.all import *
from fastcore.parallel import *

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import numpy as np
from scipy.signal import butter, filtfilt
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from tqdm import tqdm,trange,tqdm_notebook
from multiprocessing import Pool

import matplotlib.pyplot as plt

## Create Fine-Tuning `DataLoaders`

Instead of processing the parquet files, I now load the images directly; they were saved here after preprocessing the signals and converting them into spectrograms.

In [None]:
class Upload_Dataset(Dataset):
    """Dataset of pre-rendered spectrogram PNGs with soft (vote-based) labels.

    Each sample is looked up from a metadata CSV. The image file is expected
    at ``<root>/<eeg_id>-<eeg_label_offset_seconds as int>.png``.

    Args:
        csv_path: Path to the metadata CSV. It must contain ``eeg_id``,
            ``eeg_label_offset_seconds``, ``expert_consensus`` and the six
            ``*_vote`` columns listed in ``VOTE_COLUMNS``.
        transform: Optional callable applied to the opened PIL image. When
            ``None`` the PIL image is returned unchanged.
        root: Directory that contains the PNG files.
        min_total_labels: Rows whose vote columns sum to less than this value
            are dropped. Rows with no votes at all are always dropped because
            they cannot be normalized into a distribution.

    ``__getitem__`` returns ``(image, labels, consensus)`` where ``labels`` is
    the vote vector normalized to sum to 1 (in ``VOTE_COLUMNS`` order) and
    ``consensus`` is the integer class index of ``expert_consensus``.
    """

    # Order of the entries in the soft-label vector returned by __getitem__.
    VOTE_COLUMNS = ['seizure_vote', 'lpd_vote', 'gpd_vote',
                    'lrda_vote', 'grda_vote', 'other_vote']

    def __init__(self, csv_path, transform=None, root="..", min_total_labels=10):
        self.metadata = pd.read_csv(csv_path)
        self.root = root
        self.transform = transform
        # Class indices follow VOTE_COLUMNS so that the argmax of the soft
        # label (and of a model trained on it) is directly comparable with the
        # consensus index.
        self.label_map = {'Seizure': 0, 'LPD': 1, 'GPD': 2, 'LRDA': 3, 'GRDA': 4, 'Other': 5}
        self.min_total_labels = min_total_labels

        # Filter metadata based on minimum total labels; zero-vote rows would
        # produce NaN labels (0 / 0), so they are never kept.
        vote_totals = self.metadata[self.VOTE_COLUMNS].sum(axis=1)
        keep = (vote_totals >= self.min_total_labels) & (vote_totals > 0)
        self.metadata = self.metadata[keep]

    def __len__(self):
        return len(self.metadata)

    def _vote_distribution(self, row):
        """Return the row's votes as a float array normalized to sum to 1."""
        votes = np.array([row[column] for column in self.VOTE_COLUMNS], dtype=float)
        return votes / votes.sum()

    def __getitem__(self, idx):
        # Positional lookup: the filtered frame keeps its original index labels.
        row = self.metadata.iloc[idx]
        eeg_id = row['eeg_id']
        offset_seconds = int(row['eeg_label_offset_seconds'])
        image_path = os.path.join(self.root, f"{eeg_id}-{offset_seconds}.png")
        image = Image.open(image_path)

        # transform defaults to None, so only apply it when one was supplied.
        if self.transform is not None:
            image = self.transform(image)

        labels = self._vote_distribution(row)
        consensus = self.label_map[row['expert_consensus']]

        return image, labels, consensus

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Every spectrogram image is reduced to one channel, resized to 224x224 and
# converted to a tensor before being batched.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

batch_size = 64

train_path = '/kaggle/input/hms-datasplit/train_mine.csv'
test_path = '/kaggle/input/hms-datasplit/test_mine.csv'
val_path = '/kaggle/input/hms-datasplit/val_mine.csv'
root = '/kaggle/input/hms-superlets'
filter_method = 'db8'  # Can be changed to LPF10, LPF20 or db8.
transform_method = 'SL'  # Can be changed to Superlet (SL) or MEL.

# The images live in a directory nested inside one of the same name. The path
# is built from relative components: an absolute component passed to
# os.path.join would silently discard `root`.
spec_dir_name = f"hms-harmful-brain-activity-classification-{filter_method}-{transform_method}"
pre_spec = os.path.join(root, spec_dir_name, spec_dir_name)

train_dataset = Upload_Dataset(csv_path=train_path, transform=transform, root=os.path.join(pre_spec, "train"))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Evaluation loaders are not shuffled: sample order has no effect on metrics
# aggregated over the whole split, and a fixed order keeps runs comparable.
test_dataset = Upload_Dataset(csv_path=test_path, transform=transform, root=os.path.join(pre_spec, "test"))
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

val_dataset = Upload_Dataset(csv_path=val_path, transform=transform, root=os.path.join(pre_spec, "val"))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

## Load Pre-trained Weights and Fine-tune

Here we tried four different models loaded through timm: ConvNextv2_atto, EfficientNet, Swinv2_tiny and MaxViT. You can likewise try these models with different methods and variations in parameters, starting from the already pretrained weights and fine-tuning them.

In [None]:
import torch
import torch.nn as nn
import timm
# Previously saved training checkpoint used as the starting point for fine-tuning.
# map_location="cpu" lets the file load on machines without a GPU (the notebook
# falls back to CPU elsewhere); the weights are moved to `device` with the model.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load("/kaggle/input/maxvit-tiny/2024-04-24 15_03_43.110590-db8-SL-0.6307523211319289.pt", map_location="cpu")

class ConvNext(nn.Module):
    """Wrapper around a timm ``maxvit_tiny_tf_224.in1k`` classifier.

    Despite the class name, the backbone created here is MaxViT-Tiny; the name
    is kept so existing references keep working.

    Args:
        n_classes: Number of output classes of the classification head.
        num_input_channels: Number of channels of the input images.
        pretrained: When True, the weights are loaded from the module-level
            ``checkpoint`` dict (key ``'model_state_dict'``). timm's own
            pretrained weights are never downloaded.
    """

    # Key prefixes a saved state dict may carry: "module." is added by
    # nn.DataParallel, "model." by this wrapper's `model` attribute.
    _WRAPPER_PREFIXES = ("module.", "model.")

    def __init__(self, n_classes, num_input_channels=1, pretrained=False):
        super().__init__()

        self.model = timm.create_model("maxvit_tiny_tf_224.in1k", pretrained=False, in_chans=num_input_channels, num_classes=n_classes)
        if pretrained:
            state_dict = self._strip_wrapper_prefixes(checkpoint['model_state_dict'])
            self.model.load_state_dict(state_dict)

    @classmethod
    def _strip_wrapper_prefixes(cls, state_dict):
        """Return a copy of ``state_dict`` whose keys match the bare timm model.

        Handles checkpoints saved from a DataParallel-wrapped model
        (``module.model.*``) as well as from the plain wrapper (``model.*``).
        """
        cleaned = {}
        for key, value in state_dict.items():
            for prefix in cls._WRAPPER_PREFIXES:
                if key.startswith(prefix):
                    key = key[len(prefix):]
            cleaned[key] = value
        return cleaned

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)

# Number of passes over the training data.
epoch = 10
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Wrap the model in DataParallel only when more than one GPU is visible.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Using", gpu_count, "GPUs!")
base_model = ConvNext(6, 1, True)  # 6 classes, 1 input channel, load checkpoint weights
convnext_model = (nn.DataParallel(base_model) if gpu_count > 1 else base_model).to(device)

optimizer = torch.optim.Adam(convnext_model.parameters(), lr=1e-04)

# Metric histories; in each pair, index 0 holds training values and index 1
# validation values. The *_epoch_* lists hold one entry per epoch, the others
# one entry per batch.
loss_history, accuracy_history = [[], []], [[], []]
acc_epoch_history, loss_epoch_history = [[], []], [[], []]

## Model Training and Evaluation on Fine-tuned data

In [None]:
from datetime import datetime

# Timestamp captured once, before training starts; it is embedded in the name
# of every checkpoint file written by the training loop.
current_dateTime = datetime.now()

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

# Per-epoch classification metrics. Only index 1 (validation) is filled by the
# loop below; index 0 mirrors the [train, val] layout of the other histories.
precision_history = [[], []]
recall_history = [[], []]
f1_history = [[], []]

for e in trange(epoch):
    convnext_model.train()
    print(f"====================== EPOCH {e+1} ======================")
    print("Training.....")
    for i, (data, labels, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, labels, target = data.to(device), labels.to(device), target.to(device)
        output = convnext_model(data.float())
        labels = labels.float()
        # KL divergence between the predicted distribution and the normalized votes.
        loss = F.kl_div(F.log_softmax(output, dim=1), labels, reduction='batchmean')

        loss.backward()
        nn.utils.clip_grad_norm_(convnext_model.parameters(), 3)
        optimizer.step()

        loss_history[0].append(loss.item())
        accuracy = (output.argmax(dim=1) == target).float().mean()
        accuracy_history[0].append(accuracy)

        if i % 50 == 0:
            print(f"MINIBATCH {i+1}/{len(train_loader)} TRAIN LOSS : {loss_history[0][-1]}")

    print("Validation.....")
    convnext_model.eval()

    # Predictions and targets of every validation batch are collected so that
    # precision / recall / F1 describe the whole split, not just the last batch.
    val_predictions = []
    val_targets = []
    with torch.no_grad():
        for data, labels, target in val_loader:
            data, labels, target = data.to(device), labels.to(device), target.to(device)
            output = convnext_model(data.float())
            labels = labels.float()
            loss = F.kl_div(F.log_softmax(output, dim=1), labels, reduction='batchmean')

            predicted_labels = output.argmax(dim=1)
            val_predictions.append(predicted_labels.cpu())
            val_targets.append(target.cpu())

            accuracy = (predicted_labels == target).float().mean()
            loss_history[1].append(loss.item())
            accuracy_history[1].append(accuracy)

    epoch_targets = torch.cat(val_targets).numpy()
    epoch_predictions = torch.cat(val_predictions).numpy()
    precision = precision_score(epoch_targets, epoch_predictions, average='weighted', zero_division=1)
    recall = recall_score(epoch_targets, epoch_predictions, average='weighted', zero_division=1)
    f1 = f1_score(epoch_targets, epoch_predictions, average='weighted', zero_division=1)

    precision_history[1].append(precision)
    recall_history[1].append(recall)
    f1_history[1].append(f1)

    # Epoch value = mean over this epoch's batches. The slice takes exactly the
    # last len(loader) entries, i.e. every batch recorded during this epoch.
    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)
    acc_epoch_history[0].append(sum(accuracy_history[0][-n_train_batches:]) / n_train_batches)
    acc_epoch_history[1].append(sum(accuracy_history[1][-n_val_batches:]) / n_val_batches)

    loss_epoch_history[0].append(sum(loss_history[0][-n_train_batches:]) / n_train_batches)
    loss_epoch_history[1].append(sum(loss_history[1][-n_val_batches:]) / n_val_batches)

    print("====================================================")
    print(f"TRAIN ACC : {float(acc_epoch_history[0][-1])}  TRAIN LOSS : {loss_epoch_history[0][-1]}")
    print(f"VAL ACC : {float(acc_epoch_history[1][-1])}  VAL LOSS : {loss_epoch_history[1][-1]}")
    print(f"VAL PRECISION: {precision}  RECALL: {recall}  F1: {f1}")
    print("====================================================")

    # One checkpoint per epoch; the file name ends with the epoch's validation
    # loss, while the stored 'loss' / 'acc' entries are the training values.
    torch.save({
            'epoch': e,
            'model_state_dict': convnext_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss_epoch_history[0][-1],
            'acc' : acc_epoch_history[0][-1]
            }, f'{current_dateTime}-{filter_method}-{transform_method}-{loss_epoch_history[1][-1]}.pt')

# Final evaluation on the held-out test split. All quantities are accumulated
# over every batch so the reported numbers describe the whole split.
convnext_model.eval()

test_loss_sum = 0.0
test_correct = 0
test_count = 0
test_predictions = []
test_targets = []

with torch.no_grad():
    for data, labels, target in test_loader:
        data, labels, target = data.to(device), labels.to(device), target.to(device)
        output = convnext_model(data.float())
        labels = labels.float()
        # reduction='sum' here, divided by the sample count afterwards, yields
        # the per-sample KL divergence over the full split regardless of how
        # the samples were batched.
        test_loss_sum += F.kl_div(F.log_softmax(output, dim=1), labels, reduction='sum').item()

        predicted_labels = output.argmax(dim=1)
        test_correct += (predicted_labels == target).sum().item()
        test_count += target.size(0)

        test_predictions.append(predicted_labels.cpu())
        test_targets.append(target.cpu())

test_loss = test_loss_sum / test_count
test_accuracy = test_correct / test_count

all_test_targets = torch.cat(test_targets).numpy()
all_test_predictions = torch.cat(test_predictions).numpy()
precision = precision_score(all_test_targets, all_test_predictions, average='weighted', zero_division=1)
recall = recall_score(all_test_targets, all_test_predictions, average='weighted', zero_division=1)
f1 = f1_score(all_test_targets, all_test_predictions, average='weighted', zero_division=1)

print("====================================================")
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
print(f"Test Precision: {precision}, Recall: {recall}, F1: {f1}")
print("====================================================")

## Plots of Accuracy and Loss vs Epochs

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# float() accepts 0-dim tensors (on any device) as well as plain Python
# numbers, so this works however the accuracy history entries were stored.
train_acc_np = np.array([float(item) for item in acc_epoch_history[0]])
val_acc_np = np.array([float(item) for item in acc_epoch_history[1]])

# Plot accuracy history (x axis is 1-based to match the "EPOCH n" training log)
fig_acc, ax_acc = plt.subplots(figsize=(10, 5))
ax_acc.plot(range(1, len(train_acc_np) + 1), train_acc_np, label='Train Accuracy')
ax_acc.plot(range(1, len(val_acc_np) + 1), val_acc_np, label='Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_title('Accuracy vs. Epoch')
ax_acc.legend()
ax_acc.grid(True)
plt.show()

# Plot loss history
fig_loss, ax_loss = plt.subplots(figsize=(10, 5))
ax_loss.plot(range(1, len(loss_epoch_history[0]) + 1), loss_epoch_history[0], label='Train Loss')
ax_loss.plot(range(1, len(loss_epoch_history[1]) + 1), loss_epoch_history[1], label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Loss vs. Epoch')
ax_loss.legend()
ax_loss.grid(True)
plt.show()