In [28]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader, Subset
import torch.optim as optim
import torchvision
import torchvision.models as models

import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import random
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, roc_auc_score, roc_curve

In [31]:
def show(image, label):
    """Render ``image`` with matplotlib, titled with its label and without axes.

    Args:
        image: Anything ``plt.imshow`` accepts. The original note on this
            function says images are expected in PIL 'L' (single-channel) mode.
        label: Value interpolated into the figure title.
    """
    caption = f"Label: {label}"
    # The 'gray' colormap is requested because inputs are expected to be
    # single-channel images.
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    plt.title(caption)
    plt.show()

In [32]:
# --- Training hyperparameters ---
EPOCHS = 10      # number of passes over the training loader in train_and_evaluate
BATCH_SIZE = 8   # presumably the DataLoader batch size; not referenced in the cells shown here
LR = 1e-3        # initial Adam learning rate used by train_and_evaluate

# --- Filesystem locations (Kaggle-style absolute paths) ---
source_base_dir = '/kaggle/input/processed-lungs'  # root under which ChestXRayDataset resolves image files
main_dest_dir = '/kaggle/working/'                 # not referenced in the cells shown here

class ChestXRayDataset(Dataset):
    """Binary image dataset: label 1.0 for nodule images, 0.0 for non-nodule images.

    File paths are built from the ``filename`` column of two DataFrames and
    resolved under ``source_base_dir`` in the ``processed_lungs/nodule`` and
    ``processed_lungs/non_nodule`` sub-folders. Each item is an
    ``(image_tensor, label_tensor)`` pair where the image has been converted to
    RGB, resized, turned into a tensor and normalized, and the label is a
    scalar float32 tensor.
    """

    # Fixed (height, width) every image is resized to. A bare int passed to
    # ``transforms.Resize`` only rescales the shorter edge, so images with
    # different aspect ratios would yield tensors of different shapes that the
    # default DataLoader collate function cannot stack into a batch.
    IMAGE_SIZE = (224, 224)

    def __init__(self, positive_df=None, negative_df=None):
        """Collect image paths and labels for both cohorts.

        Args:
            positive_df: DataFrame with a ``filename`` column listing nodule
                images. When omitted, the notebook-level global ``nodule_df``
                is used, so it must already be defined.
            negative_df: DataFrame with a ``filename`` column listing
                non-nodule images. When omitted, the notebook-level global
                ``second_df`` is used, so it must already be defined.
        """
        if positive_df is None:
            positive_df = nodule_df
        if negative_df is None:
            negative_df = second_df

        self.image_paths = []
        self.labels = []
        # Channel mean/std commonly used with ImageNet-pretrained torchvision models.
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # Both pipelines are currently identical; they are kept as separate
        # objects so one class can later be given its own transform.
        self.transform_positive = self._build_transform()
        self.transform_negative = self._build_transform()

        cohorts = [
            (positive_df, 1.0, 'processed_lungs/nodule'),
            (negative_df, 0.0, 'processed_lungs/non_nodule'),
        ]
        for cohort_df, label_value, source_folder in cohorts:
            # Only the filename column is needed, so iterate it directly
            # instead of materialising every row with iterrows().
            for image_filename in cohort_df['filename']:
                self.image_paths.append(
                    os.path.join(source_base_dir, source_folder, image_filename)
                )
                self.labels.append(label_value)

    def _build_transform(self):
        """Return a fresh resize -> tensor -> normalize pipeline."""
        return transforms.Compose([
            transforms.Resize(self.IMAGE_SIZE),
            transforms.ToTensor(),
            self.normalize
        ])

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image_tensor, label_tensor)``."""
        img_path = self.image_paths[index]
        label = self.labels[index]
        # The context manager closes the underlying file as soon as the RGB
        # copy has been produced.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        if label == 1.0:
            image = self.transform_positive(image)
        else:
            image = self.transform_negative(image)

        return image, torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples currently held by the dataset."""
        return len(self.image_paths)

    def tackle_idxs(self, idxs):
        """Restrict the dataset, in place, to the samples at ``idxs`` in shuffled order.

        The indices refer to the dataset's contents at call time. All other
        samples are discarded, so calling this twice with the same indices does
        not select the same samples. Shuffling uses the global ``random``
        module state. An empty ``idxs`` leaves the dataset empty.

        Args:
            idxs: Iterable of integer positions into the current sample list.
        """
        selected = [(self.image_paths[i], self.labels[i]) for i in idxs]
        random.shuffle(selected)
        # Two comprehensions instead of zip(*selected): unpacking the result of
        # zip(*[]) fails when nothing was selected.
        self.image_paths = [path for path, _ in selected]
        self.labels = [label for _, label in selected]

In [33]:
class Resnet34(nn.Module):
    """ResNet-34 loaded with the 'IMAGENET1K_V1' weights and a single-logit head.

    Every backbone parameter is frozen except those whose name starts with
    ``layer3``, ``layer4`` or ``fc``. The stock ``fc`` layer is then replaced
    by Dropout -> Linear(in_features, 256) -> ReLU -> Dropout -> Linear(256, 1).
    No sigmoid is applied, so ``forward`` returns raw logits.
    """

    def __init__(self):
        super().__init__()

        backbone = models.resnet34(weights='IMAGENET1K_V1')

        # Single pass over the parameters: a parameter is trainable exactly
        # when its name begins with one of these prefixes, frozen otherwise.
        trainable_prefixes = ('layer3', 'layer4', 'fc')
        for param_name, param in backbone.named_parameters():
            param.requires_grad = param_name.startswith(trainable_prefixes)

        # Swap the stock classifier for the new head; freshly created layers
        # are trainable by default.
        head_in_features = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(head_in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1)
        )

        self.resnet34 = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output logits."""
        logits = self.resnet34(x)
        return logits

In [34]:
def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    When ``optimizer`` is given, the model is put in train mode and updated
    after every batch. Otherwise the model is put in eval mode and gradient
    tracking is disabled. Returns NaN if the loader yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    samples_seen = 0
    with torch.set_grad_enabled(is_training):
        # tqdm gives per-batch progress without flooding the cell output.
        for images, labels in tqdm(loader, desc=desc, leave=False):
            images = images.to(device)
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits.view(-1), labels)
            if is_training:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean; re-weight it by the batch size
            # so a smaller final batch does not skew the epoch average.
            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            samples_seen += batch_size

    # Average over the samples actually processed rather than
    # len(loader.dataset): the two differ when the loader uses a sampler or
    # drop_last, and the dataset may be empty.
    return loss_sum / samples_seen if samples_seen else float('nan')


def train_and_evaluate(model, model_name, train_loader, val_loader, test_loader, pos_weight):
    """Train ``model`` for ``EPOCHS`` epochs, then evaluate it on ``test_loader``.

    Training uses Adam (learning rate ``LR``), ``BCEWithLogitsLoss`` weighted
    by ``pos_weight``, and a ``ReduceLROnPlateau`` scheduler driven by the
    validation loss. The model is moved to the notebook-level ``device``, which
    must be defined before this function is called.

    Args:
        model: Module producing one logit per sample.
        model_name: Used in log messages and in the output file names.
        train_loader: Batches of ``(images, labels)`` used for optimisation.
        val_loader: Batches of ``(images, labels)`` used for the scheduler.
        test_loader: Batches of ``(images, labels)`` used for final metrics.
        pos_weight: Tensor passed to ``BCEWithLogitsLoss`` as ``pos_weight``.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``auc``. ``auc`` is NaN when the test labels contain fewer than two
        classes.

    Side effects:
        Writes ``"{model_name}_final_model_new1 (2).pth"`` (the model's
        state_dict) and, when the ROC curve is defined,
        ``"{model_name}_roc_curve_new1 (2).png"`` using relative paths.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=1)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))

    print(f"\n----- Training {model_name} -----")
    for epoch in range(EPOCHS):
        epoch_tag = f'Epoch {epoch+1}/{EPOCHS}'
        avg_train_loss = _run_epoch(
            model, train_loader, criterion, optimizer, desc=f'{epoch_tag} [train]'
        )
        avg_val_loss = _run_epoch(
            model, val_loader, criterion, desc=f'{epoch_tag} [val]'
        )

        # Skip the scheduler when there was no validation data, so a NaN
        # is not counted as a "bad" epoch.
        if not np.isnan(avg_val_loss):
            scheduler.step(avg_val_loss)

        print(epoch_tag)
        print(f'Train Loss: {avg_train_loss:.4f}')
        print(f'Val Loss: {avg_val_loss:.4f}\n')

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Test', leave=False):
            probs = torch.sigmoid(model(images.to(device))).view(-1).cpu().numpy()
            all_probs.extend(probs)
            # Fixed 0.5 decision threshold on the predicted probability.
            all_preds.extend((probs >= 0.5).astype(float))
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    # ROC/AUC is undefined unless both classes occur in the test labels.
    # Guarding here keeps a degenerate test split from raising before the
    # trained weights are saved below.
    roc_defined = np.unique(all_labels).size > 1
    if roc_defined:
        auc = roc_auc_score(all_labels, all_probs)
        fpr, tpr, _ = roc_curve(all_labels, all_probs)
    else:
        auc = float('nan')
        print('ROC AUC is undefined: the test labels contain fewer than two classes.')

    print(f'----- {model_name} Test Metrics -----')
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(f'ROC AUC: {auc:.4f}\n')

    if roc_defined:
        plt.figure()
        plt.plot(fpr, tpr, label=f'ROC curve (area = {auc:.4f})')
        plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'Receiver Operating Characteristic - {model_name}')
        plt.legend(loc="lower right")
        plt.savefig(f"{model_name}_roc_curve_new1 (2).png")
        plt.close()

    torch.save(model.state_dict(), f"{model_name}_final_model_new1 (2).pth")

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}