In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel
from torch.nn.functional import softmax, one_hot
import pandas as pd 
import os
from tqdm import tqdm
from torchvision.transforms import transforms
import matplotlib.cm as cm
from PIL import Image
import numpy as np
import wandb

In [2]:
# Load the training metadata CSV into a DataFrame for exploration.
df = pd.read_csv("../train.csv")

In [3]:
# Preview the first rows to check which columns are available.
df.head()

Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,1628180742,0,0.0,353733,0,0.0,127492639,42516,Seizure,3,0,0,0,0,0
1,1628180742,1,6.0,353733,1,6.0,3887563113,42516,Seizure,3,0,0,0,0,0
2,1628180742,2,8.0,353733,2,8.0,1142670488,42516,Seizure,3,0,0,0,0,0
3,1628180742,3,18.0,353733,3,18.0,2718991173,42516,Seizure,3,0,0,0,0,0
4,1628180742,4,24.0,353733,4,24.0,3080632009,42516,Seizure,3,0,0,0,0,0


In [4]:
# Distinct values of the consensus label column; these become the class names.
classes = df['expert_consensus'].unique()

In [5]:
# Class name -> integer index, numbered in the order the names appear in `classes`.
mapping = dict(zip(classes, range(len(classes))))

In [6]:
# Number of distinct target classes (size of the classifier output).
num_classes = len(classes)

In [7]:
# Colormap used to turn spectrogram values into RGBA pixels.
# `cm.get_cmap` is deprecated and removed in newer matplotlib releases; the
# registered colormap is available directly as an attribute of `matplotlib.cm`.
cmap = cm.viridis

  cmap = cm.get_cmap("viridis")


In [8]:
class spectrogramDataset(Dataset):
    """Dataset that renders stored spectrograms as images for classification.

    Items are looked up by row position in the metadata CSV. For each row the
    matching ``<spectrogram_id>.parquet`` file is read from ``path``, its
    ``time`` column is dropped, and the remaining values are pushed through
    the module-level ``cmap`` colormap to build an RGBA ``PIL.Image``. The
    label is converted to an integer via the module-level ``mapping`` dict.

    Args:
        csv_file: Path of the metadata CSV; it must provide the
            ``spectrogram_id`` and ``expert_consensus`` columns.
        path: Directory that contains the ``<spectrogram_id>.parquet`` files.
        transform: Optional callable applied to the PIL image. Its result is
            sliced to the first three channels, so it must return a
            channel-first array/tensor.
    """

    def __init__(self, csv_file, path="../train_spectrograms/", transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.path = path

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the row at position ``idx``.

        Without a transform the image is the RGBA ``PIL.Image``; with a
        transform it is the transformed result limited to its first three
        channels (the alpha channel is discarded).
        """
        # Read from this instance's own table rather than a notebook-level
        # `df`, so items always agree with `csv_file` and with `__len__`.
        row = self.df.iloc[idx]
        spec_id = row['spectrogram_id']
        label = row['expert_consensus']

        spectrogram_path = os.path.join(self.path, str(spec_id) + ".parquet")
        spectrogram = pd.read_parquet(spectrogram_path).drop('time', axis=1).values

        # The colormap yields RGBA floats; scale to 0-255 bytes for PIL.
        # NOTE(review): the raw values are passed to the colormap unscaled;
        # presumably they should be normalised to [0, 1] first - confirm.
        spectrogram = Image.fromarray((cmap(spectrogram) * 255).astype(np.uint8))

        if self.transform:
            # Keep only the first three channels, dropping alpha.
            spectrogram = self.transform(spectrogram)[:3, :, :]

        return spectrogram, mapping[label]
        
        

In [9]:
# Preprocessing applied to every spectrogram image before it reaches the model.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed 224x224 input size
        transforms.ToTensor(),  # PIL image -> PyTorch tensor
    ]
)

In [10]:
# Build the full dataset from the training CSV with the preprocessing above.
dataset = spectrogramDataset("../train.csv",transform=transform)

In [11]:
# Number of samples available in the dataset.
len(dataset)

106800

In [12]:
# Spot-check one sample: the image tensor and its integer class index.
dataset[10000]

(tensor([[[0.9922, 0.9922, 0.9922,  ..., 0.1804, 0.1647, 0.1686],
          [0.9922, 0.9922, 0.9922,  ..., 0.1490, 0.1804, 0.2196],
          [0.9922, 0.9922, 0.9922,  ..., 0.1765, 0.1922, 0.2235],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
         [[0.9059, 0.9059, 0.9059,  ..., 0.5059, 0.4980, 0.4706],
          [0.9059, 0.9059, 0.9059,  ..., 0.5176, 0.4314, 0.3451],
          [0.9059, 0.9059, 0.9059,  ..., 0.4471, 0.4078, 0.3294],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
         [[0.1412, 0.1412, 0.1412,  ..., 0.5373, 0.5451, 0.5490],
          [0.1412, 0.1412, 0.1412,  ..., 0.5451, 0.5490, 0.5490],
          [0.1412, 0.1412, 0.1412,  ...,

In [13]:
# 80/20 train/validation split sizes; the validation set takes the remainder
# so the two lengths always add up to the full dataset.
n_samples = len(dataset)
train_len = int(0.8 * n_samples)
val_len = n_samples - train_len

In [14]:
# Show the validation set size.
val_len

21360

In [15]:
# Seed the split so the train/validation partition is identical on every run;
# otherwise validation metrics are not comparable between runs.
# NOTE(review): rows are split individually, so rows sharing an eeg_id /
# patient_id can presumably land on both sides - confirm whether a grouped
# split is needed.
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(42),
)

In [16]:
# Shuffle only the training batches; validation metrics do not depend on
# sample order, so the validation loader keeps a fixed order.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

In [17]:
class ViTClassifier(torch.nn.Module):
    """Pretrained ViT backbone followed by a linear classification head.

    Args:
        num_classes: Number of output classes of the linear head.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # The attribute name must match the one used in `forward`.
        self.classifier = torch.nn.Linear(self.vit.config.hidden_size, num_classes)

    def forward(self, images):
        """Return class probabilities for a batch of images.

        Args:
            images: Batch of images accepted by the ViT backbone.

        Returns:
            Tensor with one row per image; each row is a softmax distribution
            over the classes.
        """
        output = self.vit(images)
        # Classify from the first token of the last hidden state.
        output = self.classifier(output.last_hidden_state[:, 0])
        output = softmax(output, dim=1)
        return output

In [18]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [19]:
# Instantiate the classifier with one output per class and move it to `device`.
model = ViTClassifier(num_classes=num_classes).to(device)

In [20]:
# KL-divergence loss summed over the batch and divided by the batch size.
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
# Adam over every model parameter (backbone and head).
# NOTE(review): this learning rate is duplicated in the `args` dict logged to
# wandb - keep the two values in sync.
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

In [21]:
# Hyperparameters recorded as the run configuration.
args = dict(lr=0.001, loss="KLDivloss")

In [22]:
# Start a Weights & Biases run and attach the hyperparameters as its config.
wandb.init(config=args)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mjatinsingh[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115241644438357, max=1.0…

In [23]:
# Register the model with wandb, logging on every step (log_freq=1).
# NOTE(review): logging every step is presumably expensive for a model of
# this size - confirm this frequency is intended.
wandb.watch(model, log_freq=1)

[]

In [24]:
num_epochs = 10

for epoch in range(num_epochs):
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    pbar = tqdm(train_loader)

    for spectrograms, labels in pbar:
        labels_onehot = one_hot(labels, num_classes=num_classes).float().to(device)
        spectrograms = spectrograms.to(device)
        optimizer.zero_grad()

        outputs = model(spectrograms)
        # The model returns probabilities while the loss is fed their log;
        # clamp first so a probability that underflowed to 0 cannot yield -inf.
        loss = loss_fn(outputs.clamp_min(1e-8).log(), labels_onehot)
        loss.backward()
        optimizer.step()

        # Report a plain Python float rather than a graph-attached tensor.
        train_loss = loss.item()
        pbar.set_description(f'train loss: {train_loss}')
        wandb.log({"train loss": train_loss})

    # ---- Validation pass: no gradients, model in evaluation mode ----
    model.eval()
    with torch.no_grad():
        correct = total = 0
        val_loss = 0.0
        pbar = tqdm(val_loader)

        for spectrograms, labels in pbar:
            labels = labels.to(device)
            labels_onehot = one_hot(labels, num_classes).float()
            spectrograms = spectrograms.to(device)
            outputs = model(spectrograms)

            # Predicted class = index of the highest probability.
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = loss_fn(outputs.clamp_min(1e-8).log(), labels_onehot).item()
            val_loss += loss
            pbar.set_description(f'val loss: {loss}')
            wandb.log({"val loss": loss})

        # Epoch-level summary: accuracy in percent and mean per-batch loss.
        accuracy = 100 * correct / total
        kl_divergence = val_loss / len(val_loader)

        wandb.log({"accuracy": accuracy, "kl divergence": kl_divergence})

        print("Validation Accuracy: ", accuracy)
        print("KL divergence loss: ", kl_divergence)

# Persist the trained weights once all epochs have finished.
torch.save(model.state_dict(), 'trained_model.pt')

  0%|          | 0/2670 [00:00<?, ?it/s]

: 