In [2]:
!pip install natsort

Collecting natsort
  Downloading natsort-8.4.0-py3-none-any.whl.metadata (21 kB)
Downloading natsort-8.4.0-py3-none-any.whl (38 kB)
Installing collected packages: natsort
Successfully installed natsort-8.4.0


In [3]:
import gc

# Clear memory
gc.collect()
import torch

torch.cuda.empty_cache()

In [4]:
import sys
sys.path.append('/kaggle/input/resnet-3d')

In [5]:
sys.path.append('/kaggle/input/resize-and-standardization')

In [6]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
import sys
from resize_standard import produce_label
from resize_standard import resize
from resize_standard import standardize_dicom
from resnet_3d import ResNet3D
import json
from tqdm import tqdm  # Import tqdm for progress bars

In [8]:
kf = KFold(n_splits=5, shuffle=True, random_state=42)

In [9]:
data_dir='/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'

In [10]:
with open('/kaggle/input/label-nec/list.json', 'r') as f:
    total_size = json.load(f)

In [11]:
with open('/kaggle/input/label-nec/json_label_coordinates (2).json', 'r') as f:
    label_coordinates = json.load(f)

In [12]:
series_pat_list=[]
for i in os.listdir(f'{data_dir}/train_images'):
    for j in os.listdir(f'{data_dir}/train_images/{i}'):
        new_list=[i,j]
        series_pat_list.append(new_list)
series_pat_list.remove(['3008676218','542282425'])
series_pat_list.remove(['3008676218','3636216534'])
series_pat_list.remove(['3637444890','3892989905'])


In [13]:
series_list=[]
for j in range(len(series_pat_list)):
    series_list.append(series_pat_list[j][1])
series_list=np.array(series_list)

In [14]:
convv3d=standardize_dicom(series_pat_list,'866293114',data_dir)
print(convv3d.shape)
convvv3d=resize(convv3d)

torch.Size([18, 512, 512])


In [15]:
print(convvv3d.squeeze(0).shape)

torch.Size([1, 25, 600, 600])


In [16]:
class DICOMDataset(Dataset):
    def __init__(self, series_list, produce_label, data_dir,transform=None):
        self.series_list = series_list
        self.produce_label = produce_label
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.series_list)

    def __getitem__(self, idx):
        
        series_id = self.series_list[idx]
        
        # For simplicity, we'll just take the first series_id
        dicom_3d=standardize_dicom(series_pat_list,series_id,data_dir)
        final_input=resize(dicom_3d).squeeze(0)
        final_label=produce_label(int(series_id),label_coordinates,total_size)
        final_label=torch.tensor(final_label)
        final_input=final_input.clone().detach()
        
        if self.transform:
            final_input = self.transform(final_input)
        return final_input, final_label


In [18]:
model = ResNet3D(pretrained=True)  # This is an example; adjust as needed for your model

In [19]:
model.fc=nn.Linear(2048,30)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [20]:
for param in model.parameters():
    param.requires_grad = False

In [21]:
def unfreeze_layers(model, num_layers_to_unfreeze):
    """
    Unfreeze the last `num_layers_to_unfreeze` layers of the model.
    """
    layers = list(model.children())
    for layer in layers[-num_layers_to_unfreeze:]:
        for param in layer.parameters():
            param.requires_grad = True

In [22]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=model.to(device)


In [31]:
num_epochs=10
best_val_loss = float('inf')
save_dir='/kaggle/working/'
os.makedirs(save_dir, exist_ok=True)  # Create the directory if it does not exist
best_model_path = os.path.join(save_dir, 'best_model.pth')
for train_index, test_index in kf.split(series_list):
    train_series_list, test_series_list = series_list[train_index], series_list[test_index]
    train_dataset = DICOMDataset(train_series_list, produce_label, data_dir)
    val_dataset = DICOMDataset(test_series_list, produce_label, data_dir)
    train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=3, shuffle=False, num_workers=4)
    # Training loop for this fold
    
    for epoch in range(10):
        if epoch > -1:  # Start unfreezing layers after the first epoch
            unfreeze_layers(model, epoch)  # Gradually unfreeze layers

    
    # Update the optimizer to include all trainable parameters
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)    
        model.train()
        running_loss=0.0
        for batch in tqdm(train_loader):
            inputs, labels = batch
            labels=labels.float()
            inputs = inputs.float()  # Convert inputs to float type
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        model.eval()
        torch.cuda.empty_cache()

        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                inputs, labels = batch
                inputs = inputs.float()  # Convert inputs to float type
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        torch.cuda.empty_cache()
        val_loss /= len(val_loader)
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {running_loss/len(train_loader)}, Val Loss: {val_loss}')        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved with validation loss: {best_val_loss}')
best_model_weights = torch.load(best_model_path)
model.load_state_dict(best_model_weights)


 10%|█         | 176/1678 [02:59<25:28,  1.02s/it]


KeyboardInterrupt: 

In [30]:
loss.item()
val_loss/len(val_loader)

8.522440585668171e-05