In [1]:
pip install monai

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0
Note: you may need to restart the kernel to use updated packages.


In [30]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
from sklearn.model_selection import train_test_split
import torch
import pydicom
from torch.utils.data import Dataset, DataLoader

# Transformations
from scipy.ndimage import zoom 
from monai.transforms import (
   Compose,
   ScaleIntensityd,
   NormalizeIntensityd,
   ScaleIntensityRanged,
   RandFlipd,
   RandRotate90d,
   RandShiftIntensityd
)

# Model
from monai.networks.nets import DenseNet121, DenseNet169, EfficientNetBN
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

In [5]:
source_dir = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train'
annotations_file = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv'

In [6]:
annotations_df = pd.read_csv(annotations_file,
                            dtype = {"BraTS21ID": str})
annotations_df.head()

Unnamed: 0,BraTS21ID,MGMT_value
0,0,1
1,2,1
2,3,0
3,5,1
4,6,1


## Define the problematic cases

In [7]:
problematic_cases = ['00109', '00123', '00709']

## Split Data by Patient Cases

In [8]:
def get_slice_paths(patient_modality_path):
    all_slices = []
    for slice_img in os.listdir(patient_modality_path):
        if slice_img.endswith('.dcm'):
            all_slices.append(os.path.join(patient_modality_path, slice_img))
    return sorted(all_slices)

In [9]:
def split_data(source_dir, problematic_cases, train_size, test_size, val_size, modality):
        # Get all patient cases
    patient_cases = [ patient_id for patient_id in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, patient_id)) ]
    
    # Filter out the problematic cases
    patient_cases = [case for case in patient_cases if case not in problematic_cases ]
    
    train_val_cases, test_cases = train_test_split(
        patient_cases,
        test_size = test_size,
        random_state = 42
    )
    
    val_size_adjusted = val_size/(train_size + val_size)
    
    train_cases, val_cases = train_test_split(
        train_val_cases,
        test_size = val_size_adjusted,
        random_state =42    )
    return {
        'train':{
            'patient_ids': train_cases,
            'paths': {patient_id: os.path.join(source_dir, patient_id) for patient_id in train_cases},
            'slices': {patient_id: get_slice_paths(os.path.join(source_dir, patient_id, modality)) for patient_id in train_cases}
        },
        'val':{
            'patient_ids': val_cases,
            'paths': {patient_id: os.path.join(source_dir, patient_id) for patient_id in val_cases},
            'slices': {patient_id: get_slice_paths(os.path.join(source_dir, patient_id, modality)) for patient_id in val_cases}     
        
        },
        'test':{
            'patient_ids': test_cases,
            'paths': {patient_id: os.path.join(source_dir, patient_id) for patient_id in test_cases},
            'slices': {patient_id: get_slice_paths(os.path.join(source_dir, patient_id, modality)) for patient_id in test_cases}     
        
        }
    }

In [10]:
# Split data
splits = split_data(
    source_dir=source_dir,
    train_size=0.7,
    test_size=0.15,
    val_size=0.15,
    problematic_cases = problematic_cases,
    modality = 'FLAIR')

In [11]:
# splits['train']['slices']

## Define the Custom BraTs Dataset

In [12]:
class BraTsDataset(Dataset):

    def __init__(self, data_dict, annotations_df, transforms = None, cache_size = 0):
        self.patient_ids = data_dict['patient_ids']
        self.slice_paths = data_dict['slices']
        self.transforms = transforms

        self.cache_size = cache_size
        self.cache = {}
        self.labels = dict(zip(annotations_df['BraTS21ID'], annotations_df['MGMT_value'])) 

    def __len__(self):
        return len(self.patient_ids)

    def load_volume(self, patient_id):
        if patient_id in self.cache:
            return self.cache[patient_id]
        slices = []
        for slice_path in self.slice_paths[patient_id]:
            dicom_image = pydicom.dcmread(slice_path)
            image_2d = dicom_image.pixel_array
            resized_slice = zoom(image_2d, (64/image_2d.shape[0], 64/image_2d.shape[1]))
            slices.append(resized_slice)
           
        volume = np.stack(slices, axis=-1)
        volume = np.expand_dims(volume, axis=0)
        volume_tensor = torch.from_numpy(volume).float()

        # Print range before normalization
        # print(f"Before normalization: [{volume_tensor.min():.3f}, {volume_tensor.max():.3f}]")

        # Add normalization here
        volume_tensor = (volume_tensor - volume_tensor.min()) / (volume_tensor.max() - volume_tensor.min())
       
        # Print range after normalization
        # print(f"After normalization: [{volume_tensor.min():.3f}, {volume_tensor.max():.3f}]")
       
        
        if len(self.cache) < self.cache_size:
            self.cache[patient_id] = volume_tensor
           
        return volume_tensor

    
    def __getitem__(self, idx):
        
        patient_id = self.patient_ids[idx]
        volume = self.load_volume(patient_id)
        label = torch.tensor(self.labels[patient_id])
        
        data = {"image": volume, 
                "patient_id": patient_id,
                "label": label
               }
        
        # Add print statements for debugging
        # print(f"Before transform range: [{data['image'].min():.3f}, {data['image'].max():.3f}]")
        # print("Data keys:", data.keys())
        
        # if self.transforms:
        #     data = self.transforms(data)
            
        # print(f"After transform range: [{data['image'].min():.3f}, {data['image'].max():.3f}]")
            
        
        return data
        
    

In [13]:
train_split = splits['train']
val_split = splits['val']
test_split = splits['test']

In [14]:
# Check split data structure
print("Train split keys:", train_split.keys())
print("First few patient IDs:", train_split['patient_ids'][:3])

# Check if paths are correct
first_patient = train_split['patient_ids'][0]
print("First patient paths:", train_split['slices'][first_patient][:3])

# Create dataset and check
train_dataset = BraTsDataset(train_split, annotations_df)
sample = train_dataset[0]
print(f"Sample shape: {sample['image'].shape}")

Train split keys: dict_keys(['patient_ids', 'paths', 'slices'])
First few patient IDs: ['00316', '01007', '00003']
First patient paths: ['/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00316/FLAIR/Image-1.dcm', '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00316/FLAIR/Image-10.dcm', '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00316/FLAIR/Image-11.dcm']
Sample shape: torch.Size([1, 64, 64, 60])


## Define the Transformations

In [15]:
train_transforms = Compose([
ScaleIntensityRanged(
        keys=["image"],
        a_min=-76,
        a_max=3158,
        b_min=0.0,
        b_max=1.0,
        clip=True
    ),    
    RandFlipd(keys=["image"], spatial_axis=[0, 1], prob=0.5),
    RandRotate90d(keys=["image"], prob=0.5, spatial_axes=[0, 1]),
    RandShiftIntensityd(keys=["image"], prob=0.5, offsets=0.1),
])



In [16]:
# Add print to verify transform is in chain
print("Transforms:", train_transforms.transforms)

Transforms: (<monai.transforms.intensity.dictionary.ScaleIntensityRanged object at 0x7ee9b1a66260>, <monai.transforms.spatial.dictionary.RandFlipd object at 0x7ee9b1a67700>, <monai.transforms.spatial.dictionary.RandRotate90d object at 0x7ee9b1a674f0>, <monai.transforms.intensity.dictionary.RandShiftIntensityd object at 0x7ee9b1a65c60>)


In [17]:
val_transforms = Compose([
   ScaleIntensityd(
       keys=["image"],
       minv=0.0,
       maxv=1.0,
   ),
])

In [18]:
# Create test transforms and loader
test_transforms = Compose([
   ScaleIntensityd(
       keys=["image"],
       minv=0.0,
       maxv=1.0,
   ),
])

In [19]:
train_dataset = BraTsDataset(train_split, annotations_df, transforms=train_transforms)

In [20]:
train_dataset = BraTsDataset(train_split, annotations_df)
val_dataset = BraTsDataset(val_split, annotations_df)
test_dataset = BraTsDataset(test_split, annotations_df)

In [21]:
# Check split data structure
print("Train split keys:", train_split.keys())
print("First few patient IDs:", train_split['patient_ids'][:3])

# Check if paths are correct
first_patient = train_split['patient_ids'][0]
print("First patient paths:", train_split['slices'][first_patient][:3])

# Create dataset and check
train_dataset = BraTsDataset(train_split, annotations_df)
sample= train_dataset[0]
print(f"Train Sample shape: {sample['image'].shape}")
print()

val_dataset = BraTsDataset(val_split, annotations_df)
sample= val_dataset[0]
print(f"Validation Sample shape: {sample['image'].shape}")
print()
test_dataset = BraTsDataset(val_split, annotations_df)
sample= test_dataset[0]
print(f"Test Sample shape: {sample['image'].shape}")

Train split keys: dict_keys(['patient_ids', 'paths', 'slices'])
First few patient IDs: ['00316', '01007', '00003']
First patient paths: ['/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00316/FLAIR/Image-1.dcm', '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00316/FLAIR/Image-10.dcm', '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train/00316/FLAIR/Image-11.dcm']
Train Sample shape: torch.Size([1, 64, 64, 60])

Validation Sample shape: torch.Size([1, 64, 64, 129])

Test Sample shape: torch.Size([1, 64, 64, 129])


## Define the DataLoader

In [22]:
# def custom_collate(batch):
    
#     min_val = min([item["image"].min().item() for item in batch])
#     max_val = max([item["image"].max().item() for item in batch])

#     # print(f"Collate input range: [{min_val:.3f}, {max_val:.3f}]")
    
#     max_depth = max([x["image"].shape[-1] for x in batch])  # x[0] is volume
#     padded_batch = []
#     labels = []
#     for data in batch:
#         volume = data["image"]
#         label = data["label"]
#         pad_size = max_depth - volume.shape[-1]
        
#         if pad_size > 0:
#             padded_volume = torch.nn.functional.pad(volume, (0, pad_size))
#             padded_batch.append(padded_volume)
        
#         else:
#             padded_batch.append(volume)
#         labels.append(label)
    
#     return { 
#         "image": torch.stack(padded_batch), 
#         "label": torch.tensor(labels)
#            }

In [33]:
def custom_collate(batch):
    max_depth = 32  # Fixed depth
    padded_batch = []
    labels = []
    
    for data in batch:
        volume = data['image']
        label = data['label']
        
        current_depth = volume.shape[-1]
        if current_depth < max_depth:
            pad_size = max_depth - current_depth
            padded_volume = torch.nn.functional.pad(volume, (0, pad_size))
        else:
            padded_volume = volume[..., :max_depth]
            
        padded_batch.append(padded_volume)
        labels.append(label)
    
    return {
        "image": torch.stack(padded_batch),
        "label": torch.stack(labels)
    }

In [34]:
# Create DataLoader
train_loader = DataLoader(
   train_dataset,
   batch_size=4,  # Small batch size for 3D data
   shuffle=True,
   num_workers=2,
   collate_fn=custom_collate,
   pin_memory=True  # Faster data transfer to GPU
)


val_loader = DataLoader(
   val_dataset,
   batch_size=4,  # Small batch size for 3D data
   shuffle=True,
   num_workers=2,
   collate_fn=custom_collate,
   pin_memory=True  # Faster data transfer to GPU
)

# Create DataLoader
test_loader = DataLoader(
   test_dataset,
   batch_size=4,  # Small batch size for 3D data
   shuffle=False,
   num_workers=2,
   collate_fn=custom_collate,
   pin_memory=True  # Faster data transfer to GPU
)


In [35]:
def inspect_batch(dataloader, split_name="Unknown"):
    """Inspects a batch from a given dataloader, printing shape and value range."""
    batch = next(iter(dataloader))  # Get a single batch
    
    images, labels = batch["image"], batch["label"]
    
    min_val, max_val = images.min().item(), images.max().item()
    
    print(f"\n--- {split_name} Split Batch Inspection ---")
    print(f"Image Batch Shape: {images.shape}")  # Expected: (batch_size, channels, height, width, depth)
    print(f"Label Batch Shape: {labels.shape}")  # Should match batch_size
    print(f"Value Range: [{min_val:.3f}, {max_val:.3f}]")


In [36]:
# inspect_batch(train_loader, "Train")
# inspect_batch(val_loader, "Validation")
# inspect_batch(test_loader, "Test")


# Define the model

In [37]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [38]:
# model = DenseNet121(
#     spatial_dims = 3,
#     in_channels = 1,
#     out_channels = 1
# ).to(device)

# model = DenseNet169(
#     spatial_dims = 3,
#     in_channels = 1,
#     out_channels = 1
# ).to(device)



In [46]:
# EfficientNet-B1
model = EfficientNetBN(
    model_name="efficientnet-b1",
    spatial_dims=3,
    in_channels=1,
    num_classes=1
).to(device)

In [47]:
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr = 1e-4)

## Set up the Training Loop

In [48]:
def train_epoch(model, train_loader, val_loader, epochs = 50):
    best_val_acc = 0.0
    for epoch in range(epochs):
        
        model.train()  # setting the model in train mode
        train_loss = 0.0 # initialize variable to store the sum of loss
        
        for batch in tqdm(train_loader, desc = f'Epoch {epoch+1}/{epochs} - Training'):
            # Pass through the training set for the no. of epochs
            images = batch['image'].to(device)
            labels = batch['label'].float().to(device)

        
            optimizer.zero_grad()
            outputs = model(images)
        
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            
        # Calculatie average loss for epoch
        avg_train_loss = train_loss/len(train_loader)
        
        # Validation Phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad(): # No gradient calculation needed
            for batch in tqdm(val_loader, desc = f'Epoch {epoch + 1}/{epochs} - Validation'):
                images = batch['image'].to(device)
                labels = batch['label'].float().to(device)
                
                outputs = model(images)
                
                val_loss += criterion(outputs.squeeze(), labels).item()

                predicted = (torch.sigmoid(outputs.squeeze()) > 0.5).int()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                
        val_accuracy = 100*correct/total
        avg_val_loss = val_loss/len(val_loader)
        
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            torch.save(model.state_dict(), '/kaggle/working/best_model.pth')
        
        # Print epoch results
        print(f'Epoch [{epoch+1}/{epochs}]')
        print(f'Training Loss: {avg_train_loss:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}')
        print(f'Validation Accuracy: {val_accuracy:.2f}%')
        print('-' * 50)


        
        

In [49]:
# Run training for 50 epochs
num_epochs = 30
train_epoch(model, train_loader, val_loader, epochs=num_epochs)

Epoch 1/30 - Training: 100%|██████████| 102/102 [03:45<00:00,  2.21s/it]
Epoch 1/30 - Validation: 100%|██████████| 22/22 [00:54<00:00,  2.47s/it]


Epoch [1/30]
Training Loss: 3.4046
Validation Loss: 0.8617
Validation Accuracy: 54.55%
--------------------------------------------------


Epoch 2/30 - Training: 100%|██████████| 102/102 [03:13<00:00,  1.90s/it]
Epoch 2/30 - Validation: 100%|██████████| 22/22 [00:49<00:00,  2.23s/it]


Epoch [2/30]
Training Loss: 2.2133
Validation Loss: 0.9245
Validation Accuracy: 54.55%
--------------------------------------------------


Epoch 3/30 - Training: 100%|██████████| 102/102 [03:22<00:00,  1.98s/it]
Epoch 3/30 - Validation: 100%|██████████| 22/22 [01:00<00:00,  2.76s/it]


Epoch [3/30]
Training Loss: 2.5898
Validation Loss: 0.8339
Validation Accuracy: 54.55%
--------------------------------------------------


Epoch 4/30 - Training: 100%|██████████| 102/102 [03:27<00:00,  2.04s/it]
Epoch 4/30 - Validation: 100%|██████████| 22/22 [00:50<00:00,  2.29s/it]


Epoch [4/30]
Training Loss: 2.0604
Validation Loss: 0.9808
Validation Accuracy: 54.55%
--------------------------------------------------


Epoch 5/30 - Training: 100%|██████████| 102/102 [03:15<00:00,  1.92s/it]
Epoch 5/30 - Validation: 100%|██████████| 22/22 [00:55<00:00,  2.54s/it]


Epoch [5/30]
Training Loss: 2.1129
Validation Loss: 0.6989
Validation Accuracy: 54.55%
--------------------------------------------------


Epoch 6/30 - Training: 100%|██████████| 102/102 [03:31<00:00,  2.07s/it]
Epoch 6/30 - Validation: 100%|██████████| 22/22 [00:51<00:00,  2.33s/it]


Epoch [6/30]
Training Loss: 1.9821
Validation Loss: 0.8907
Validation Accuracy: 54.55%
--------------------------------------------------


Epoch 7/30 - Training: 100%|██████████| 102/102 [03:24<00:00,  2.01s/it]
Epoch 7/30 - Validation: 100%|██████████| 22/22 [00:52<00:00,  2.40s/it]


Epoch [7/30]
Training Loss: 1.6337
Validation Loss: 0.7301
Validation Accuracy: 45.45%
--------------------------------------------------


Epoch 8/30 - Training: 100%|██████████| 102/102 [03:42<00:00,  2.18s/it]
Epoch 8/30 - Validation: 100%|██████████| 22/22 [00:53<00:00,  2.44s/it]


Epoch [8/30]
Training Loss: 1.7950
Validation Loss: 0.6969
Validation Accuracy: 55.68%
--------------------------------------------------


Epoch 9/30 - Training: 100%|██████████| 102/102 [03:28<00:00,  2.05s/it]
Epoch 9/30 - Validation: 100%|██████████| 22/22 [00:47<00:00,  2.17s/it]


Epoch [9/30]
Training Loss: 1.1090
Validation Loss: 0.7141
Validation Accuracy: 53.41%
--------------------------------------------------


Epoch 10/30 - Training: 100%|██████████| 102/102 [03:46<00:00,  2.22s/it]
Epoch 10/30 - Validation: 100%|██████████| 22/22 [00:49<00:00,  2.27s/it]


Epoch [10/30]
Training Loss: 1.5572
Validation Loss: 0.8283
Validation Accuracy: 54.55%
--------------------------------------------------


Epoch 11/30 - Training: 100%|██████████| 102/102 [03:25<00:00,  2.02s/it]
Epoch 11/30 - Validation: 100%|██████████| 22/22 [00:49<00:00,  2.23s/it]


Epoch [11/30]
Training Loss: 1.0319
Validation Loss: 2.9543
Validation Accuracy: 44.32%
--------------------------------------------------


Epoch 12/30 - Training: 100%|██████████| 102/102 [03:30<00:00,  2.06s/it]
Epoch 12/30 - Validation: 100%|██████████| 22/22 [01:03<00:00,  2.89s/it]


Epoch [12/30]
Training Loss: 1.0120
Validation Loss: 2.1087
Validation Accuracy: 55.68%
--------------------------------------------------


Epoch 13/30 - Training: 100%|██████████| 102/102 [03:42<00:00,  2.18s/it]
Epoch 13/30 - Validation: 100%|██████████| 22/22 [00:52<00:00,  2.39s/it]


Epoch [13/30]
Training Loss: 1.1381
Validation Loss: 3.8838
Validation Accuracy: 48.86%
--------------------------------------------------


Epoch 14/30 - Training: 100%|██████████| 102/102 [03:31<00:00,  2.08s/it]
Epoch 14/30 - Validation: 100%|██████████| 22/22 [00:55<00:00,  2.54s/it]


Epoch [14/30]
Training Loss: 1.0956
Validation Loss: 1.0064
Validation Accuracy: 61.36%
--------------------------------------------------


Epoch 15/30 - Training: 100%|██████████| 102/102 [03:23<00:00,  1.99s/it]
Epoch 15/30 - Validation: 100%|██████████| 22/22 [00:47<00:00,  2.17s/it]


Epoch [15/30]
Training Loss: 0.7975
Validation Loss: 0.9574
Validation Accuracy: 60.23%
--------------------------------------------------


Epoch 16/30 - Training: 100%|██████████| 102/102 [03:36<00:00,  2.12s/it]
Epoch 16/30 - Validation: 100%|██████████| 22/22 [00:53<00:00,  2.43s/it]


Epoch [16/30]
Training Loss: 0.7350
Validation Loss: 0.8940
Validation Accuracy: 61.36%
--------------------------------------------------


Epoch 17/30 - Training: 100%|██████████| 102/102 [03:26<00:00,  2.02s/it]
Epoch 17/30 - Validation: 100%|██████████| 22/22 [00:54<00:00,  2.46s/it]


Epoch [17/30]
Training Loss: 0.7473
Validation Loss: 1.4208
Validation Accuracy: 50.00%
--------------------------------------------------


Epoch 18/30 - Training: 100%|██████████| 102/102 [03:46<00:00,  2.22s/it]
Epoch 18/30 - Validation: 100%|██████████| 22/22 [00:52<00:00,  2.39s/it]


Epoch [18/30]
Training Loss: 0.7294
Validation Loss: 4.2420
Validation Accuracy: 53.41%
--------------------------------------------------


Epoch 19/30 - Training: 100%|██████████| 102/102 [03:21<00:00,  1.98s/it]
Epoch 19/30 - Validation: 100%|██████████| 22/22 [00:53<00:00,  2.45s/it]


Epoch [19/30]
Training Loss: 0.5482
Validation Loss: 1.9928
Validation Accuracy: 52.27%
--------------------------------------------------


Epoch 20/30 - Training: 100%|██████████| 102/102 [03:22<00:00,  1.98s/it]
Epoch 20/30 - Validation: 100%|██████████| 22/22 [00:54<00:00,  2.46s/it]


Epoch [20/30]
Training Loss: 0.4808
Validation Loss: 1.2981
Validation Accuracy: 53.41%
--------------------------------------------------


Epoch 21/30 - Training: 100%|██████████| 102/102 [03:24<00:00,  2.01s/it]
Epoch 21/30 - Validation: 100%|██████████| 22/22 [00:56<00:00,  2.55s/it]


Epoch [21/30]
Training Loss: 0.4719
Validation Loss: 2.2453
Validation Accuracy: 51.14%
--------------------------------------------------


Epoch 22/30 - Training: 100%|██████████| 102/102 [03:29<00:00,  2.05s/it]
Epoch 22/30 - Validation: 100%|██████████| 22/22 [00:52<00:00,  2.39s/it]


Epoch [22/30]
Training Loss: 0.5147
Validation Loss: 1.7665
Validation Accuracy: 56.82%
--------------------------------------------------


Epoch 23/30 - Training: 100%|██████████| 102/102 [03:37<00:00,  2.13s/it]
Epoch 23/30 - Validation: 100%|██████████| 22/22 [00:58<00:00,  2.66s/it]


Epoch [23/30]
Training Loss: 0.6187
Validation Loss: 2.1586
Validation Accuracy: 47.73%
--------------------------------------------------


Epoch 24/30 - Training: 100%|██████████| 102/102 [03:45<00:00,  2.21s/it]
Epoch 24/30 - Validation: 100%|██████████| 22/22 [00:50<00:00,  2.30s/it]


Epoch [24/30]
Training Loss: 0.5600
Validation Loss: 1.5631
Validation Accuracy: 53.41%
--------------------------------------------------


Epoch 25/30 - Training: 100%|██████████| 102/102 [03:17<00:00,  1.94s/it]
Epoch 25/30 - Validation: 100%|██████████| 22/22 [00:53<00:00,  2.41s/it]


Epoch [25/30]
Training Loss: 0.4359
Validation Loss: 1.4904
Validation Accuracy: 42.05%
--------------------------------------------------


Epoch 26/30 - Training: 100%|██████████| 102/102 [03:20<00:00,  1.96s/it]
Epoch 26/30 - Validation: 100%|██████████| 22/22 [01:16<00:00,  3.46s/it]


Epoch [26/30]
Training Loss: 0.5032
Validation Loss: 1.4833
Validation Accuracy: 52.27%
--------------------------------------------------


Epoch 27/30 - Training: 100%|██████████| 102/102 [03:37<00:00,  2.13s/it]
Epoch 27/30 - Validation: 100%|██████████| 22/22 [00:52<00:00,  2.38s/it]


Epoch [27/30]
Training Loss: 0.2751
Validation Loss: 1.5049
Validation Accuracy: 51.14%
--------------------------------------------------


Epoch 28/30 - Training: 100%|██████████| 102/102 [03:45<00:00,  2.21s/it]
Epoch 28/30 - Validation: 100%|██████████| 22/22 [00:59<00:00,  2.69s/it]


Epoch [28/30]
Training Loss: 0.4440
Validation Loss: 2.4459
Validation Accuracy: 47.73%
--------------------------------------------------


Epoch 29/30 - Training: 100%|██████████| 102/102 [03:57<00:00,  2.32s/it]
Epoch 29/30 - Validation: 100%|██████████| 22/22 [00:56<00:00,  2.57s/it]


Epoch [29/30]
Training Loss: 0.3920
Validation Loss: 2.1937
Validation Accuracy: 48.86%
--------------------------------------------------


Epoch 30/30 - Training: 100%|██████████| 102/102 [03:20<00:00,  1.97s/it]
Epoch 30/30 - Validation: 100%|██████████| 22/22 [00:51<00:00,  2.35s/it]

Epoch [30/30]
Training Loss: 0.3449
Validation Loss: 1.6413
Validation Accuracy: 45.45%
--------------------------------------------------





In [51]:
def test_model(model, test_loader, criterion):
    # Load best model weights
    model.load_state_dict(torch.load('/kaggle/working/best_model.pth'))
    model.eval()  # Set to evaluation mode
    
    test_loss = 0.0
    correct = 0
    total = 0
    
    # For storing predictions and true labels
    all_predictions = []
    all_labels = []
    
    test_pbar = tqdm(test_loader, desc='Testing')
    with torch.no_grad():
        for batch in test_pbar:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(images)
            test_loss += criterion(outputs.squeeze(), labels.float()).item()
            
            predicted = (torch.sigmoid(outputs.squeeze()) > 0.5).int()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Store predictions and labels
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            # Update progress bar
            current_acc = 100 * correct / total
            test_pbar.set_postfix({'acc': f'{current_acc:.2f}%'})
    
    # Calculate final metrics
    test_accuracy = 100 * correct / total
    avg_test_loss = test_loss / len(test_loader)
    
    print('\nTest Results:')
    print(f'Test Loss: {avg_test_loss:.4f}')
    print(f'Test Accuracy: {test_accuracy:.2f}%')
    
    return test_accuracy, avg_test_loss, all_predictions, all_labels


In [52]:
# Run testing
test_accuracy, test_loss, predictions, true_labels = test_model(model, test_loader, criterion)

  model.load_state_dict(torch.load('/kaggle/working/best_model.pth'))
Testing: 100%|██████████| 22/22 [00:54<00:00,  2.50s/it, acc=61.36%]


Test Results:
Test Loss: 1.0064
Test Accuracy: 61.36%





In [45]:
print('Test Accuracy: ',test_accuracy)
print('Test Loss: ',test_loss)

Test Accuracy:  61.36363636363637
Test Loss:  0.879914247176864
