In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader, Subset
import torch.optim as optim
import torchvision
import torchvision.models as models

import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import random
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, roc_auc_score, roc_curve

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
def show(image, label):
    plt.imshow(image, cmap='gray')  # Use 'gray' since image is in 'L' mode
    plt.title(f"Label: {label}")
    plt.axis('off')
    plt.show()

In [4]:
df = pd.read_csv('/kaggle/input/processed-lungs/processed_lungs/labels.csv')

nodule_df = df[df['label'] == 1]
second_df = df[df['label'] == 0]

print("we print nodule ")
print(nodule_df)
print("now second one")
print(second_df)

we print nodule 
                   filename  label
0     00000004_000_lung.png      1
1     00000008_002_lung.png      1
2     00000013_025_lung.png      1
3     00000017_000_lung.png      1
4     00000021_000_lung.png      1
...                     ...    ...
6326  00030703_001_lung.png      1
6327  00030715_000_lung.png      1
6328  00030722_000_lung.png      1
6329  00030726_000_lung.png      1
6330  00030793_000_lung.png      1

[6331 rows x 2 columns]
now second one
                    filename  label
6331   00024601_000_lung.png      0
6332   00022677_004_lung.png      0
6333   00013534_009_lung.png      0
6334   00009609_012_lung.png      0
6335   00022058_001_lung.png      0
...                      ...    ...
12659  00003596_005_lung.png      0
12660  00018120_001_lung.png      0
12661  00012003_001_lung.png      0
12662  00020434_003_lung.png      0
12663  00021006_008_lung.png      0

[6333 rows x 2 columns]


In [5]:
BATCH_SIZE = 8
LR = 0.001
EPOCHS = 10

main_dest_dir = '/kaggle/working/'
source_base_dir = '/kaggle/input/processed-lungs'

class ChestXRayDataset(Dataset):
    def __init__(self):
        self.image_paths = []
        self.labels = []
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                     
        self.transform_positive = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            self.normalize
        ])
        
        self.transform_negative = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            self.normalize
        ])
        
        for cohort_df, label_value in [(nodule_df, 1.0), (second_df, 0.0)]:
            source_folder = 'processed_lungs/nodule' if label_value == 1 else 'processed_lungs/non_nodule'
            for _, row in cohort_df.iterrows():
                image_filename = row['filename']
                img_path = os.path.join(source_base_dir,
                                        source_folder,
                                        image_filename)
        
                self.image_paths.append(img_path)
                self.labels.append(label_value)

    def __getitem__(self, index):
        img_path = self.image_paths[index]
        label = self.labels[index]
        image = Image.open(img_path).convert('RGB')
        if label == 1.0:
            image = self.transform_positive(image)
        else:
            image = self.transform_negative(image)
            
        return image, torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        return len(self.image_paths)
        
    def tackle_idxs(self, idxs):
        image_paths_temp = []
        labels_temp = []
        
        for i in idxs:
            label = self.labels[i]
            img_path = self.image_paths[i]
            
            image_paths_temp.append(img_path)
            labels_temp.append(label)
        
        combined = list(zip(image_paths_temp, labels_temp))
        random.shuffle(combined)
        self.image_paths, self.labels = map(list, zip(*combined))

In [6]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.alexnet = models.alexnet(weights='IMAGENET1K_V1')
        
        for param in self.alexnet.parameters():
            param.requires_grad = False

        for name, param in self.alexnet.named_parameters():
            if name.startswith("classifier"):
                param.requires_grad = True

        num_ftrs = self.alexnet.classifier[6].in_features
        self.alexnet.classifier[6] = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.alexnet(x)


In [7]:
def train_and_evaluate(model, model_name, train_loader, val_loader, test_loader, pos_weight):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=1)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    
    print(f"\n----- Training {model_name} -----")
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.view(-1), labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            print(train_loss)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs.view(-1), labels)
                val_loss += loss.item() * images.size(0)
                print(val_loss)
        
        avg_val_loss = val_loss / len(val_loader.dataset)
        scheduler.step(avg_val_loss)
        
        print(f'Epoch {epoch+1}/{EPOCHS}')
        print(f'Train Loss: {train_loss/len(train_loader.dataset):.4f}')
        print(f'Val Loss: {avg_val_loss:.4f}\n')
    
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = torch.sigmoid(model(images))
            probs = outputs.view(-1).cpu().numpy()
            predicted = (probs >= 0.5).astype(float)
            all_preds.extend(predicted)
            all_probs.extend(probs)
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    auc = roc_auc_score(all_labels, all_probs)
    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    
    print(f'----- {model_name} Test Metrics -----')
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(f'ROC AUC: {auc:.4f}\n')
    
    plt.figure()
    plt.plot(fpr, tpr, label=f'ROC curve (area = {auc:.4f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic - {model_name}')
    plt.legend(loc="lower right")
    plt.savefig(f"{model_name}_roc_curve_new1 (2).png")
    plt.close()
    
    torch.save(model.state_dict(), f"{model_name}_final_model_new1 (2).pth")
    
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}

In [8]:
def main():
    dataset = ChestXRayDataset()
    n = len(dataset)
    labels = dataset.labels
    print(n)
    idxs = list(range(n))
    train_idxs, testval_idxs, train_labels, testval_labels = train_test_split(
        idxs,
        labels,
        test_size=0.30,
        stratify=labels,
        random_state=42
    )

    relative_val_size = 0.15 / 0.30
    test_idxs, val_idxs, test_labels, val_labels = train_test_split(
        testval_idxs,
        testval_labels,
        test_size=relative_val_size,
        stratify=testval_labels,
        random_state=42
    )

    print(len(train_idxs))
    print(len(val_idxs))
    print(len(test_idxs))
    
    train_dataset = ChestXRayDataset()
    train_dataset.tackle_idxs(train_idxs)
    
    val_dataset = ChestXRayDataset()
    val_dataset.tackle_idxs(val_idxs)
    
    test_dataset = ChestXRayDataset()
    test_dataset.tackle_idxs(test_idxs)
    
    pos_count = sum(train_dataset.labels)
    neg_count = len(train_dataset) - pos_count
    pos_weight = torch.tensor([neg_count / pos_count])
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    print(pos_count)
    print(neg_count)
    
    model = AlexNet()
    
    result = train_and_evaluate(model, 'AlexNet', train_loader, val_loader, test_loader, pos_weight)
    print("----- Overall Results -----")
    print(f"AlexNet: {result}")
        
if __name__ == '__main__':
    main()

12664
8864
1900
1900
4431.0
4433.0


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 188MB/s]  



----- Training AlexNet -----
5.807863235473633
13.881858825683594
20.404769897460938
30.53816795349121
45.97611713409424
49.61279845237732
69.46787428855896
82.78124785423279
88.73335003852844
99.61401391029358
121.79792618751526
127.31547665596008
131.5816524028778
139.23216032981873
144.90542197227478
153.14748644828796
162.0348765850067
168.24659371376038
175.01239037513733
188.62418675422668
194.04767537117004
197.37438011169434
205.1542525291443
212.2874732017517
220.30302000045776
233.34508752822876
238.77501344680786
244.60324907302856
251.8602418899536
259.2830820083618
264.90553092956543
272.72658348083496
279.6526527404785
290.5739860534668
298.7773714065552
305.5958423614502
312.7508397102356
316.8645706176758
327.09261989593506
336.2785243988037
341.1335072517395
345.3321099281311
350.73098134994507
359.3203568458557
366.26447057724
372.22820806503296
382.9045310020447
389.1930584907532
396.05222368240356
402.19319105148315
409.272123336792
415.30473947525024
420.141520023