In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.profiler
import torchmetrics
import torchvision
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision import transforms

In [2]:
def load_file(path):
    """Read the NumPy file at ``path`` and return its contents cast to float32.

    Used as the ``loader`` callable of the dataset folders, so it receives a
    single file path and returns the array for that sample.
    """
    array = np.load(path)
    return array.astype(np.float32)

# Pixel statistics used to standardise every image before it reaches the model.
# NOTE(review): presumably the mean/std of the preprocessed training set — confirm
# against the preprocessing step that produced the .npy files.
PIXEL_MEAN = 0.4903962485384803
PIXEL_STD = 0.24795070634161256

# Training pipeline: tensor conversion and normalisation, followed by random
# geometric augmentation and a random crop resized to 224x224.
_train_steps = [
    transforms.ToTensor(),
    transforms.Normalize(PIXEL_MEAN, PIXEL_STD),
    transforms.RandomAffine(degrees=(-5, 5), translate=(0, 0.05), scale=(0.9, 1.1)),
    transforms.RandomResizedCrop((224, 224), scale=(0.35, 1)),
]
train_transforms = transforms.Compose(_train_steps)

# Validation pipeline: the same conversion and normalisation, with no augmentation
# and no resizing (assumes the stored arrays already have the model's input size
# — TODO confirm).
_val_steps = [
    transforms.ToTensor(),
    transforms.Normalize(PIXEL_MEAN, PIXEL_STD),
]
val_transforms = transforms.Compose(_val_steps)

In [3]:
# Root folder of the RSNA pneumonia data. Later cells build paths with plain string
# concatenation (default_path + "Processed/train/"), so the value must end with a
# path separator.
#
# The location can be overridden without editing the notebook by setting the
# PNEUMONIA_DATA_DIR environment variable; the original machine-specific path is
# kept as the fallback so existing setups keep working unchanged.
_FALLBACK_DATA_DIR = "C:\\Users\\write\\Desktop\\Medical_Images\\4_Projects\\Pneumonia_Classification\\rsna-pneumonia-detection-challenge\\"
_data_dir_override = os.environ.get("PNEUMONIA_DATA_DIR")
if _data_dir_override:
    # os.path.join(path, "") appends a trailing separator only when one is missing.
    default_path = os.path.join(_data_dir_override, "")
else:
    default_path = _FALLBACK_DATA_DIR

In [4]:
# Dataset loading
# Both splits are folders of .npy files read through load_file; they differ only
# in their root directory and in the transform pipeline applied to each sample.
_npy_folder_options = dict(loader=load_file, extensions="npy")

train_dataset = torchvision.datasets.DatasetFolder(
    root=default_path + "Processed/train/",
    transform=train_transforms,
    **_npy_folder_options,
)

val_dataset = torchvision.datasets.DatasetFolder(
    root=default_path + "Processed/val/",
    transform=val_transforms,
    **_npy_folder_options,
)

In [5]:
# DataLoader settings: samples per batch, and number of worker processes used to
# load them.
batch_size, num_workers = 64, 4

In [6]:
# Report how many samples each split contains (quick sanity check on the folders).
print(
    f"Number of training samples: {len(train_dataset)}",
    f"Number of validation samples: {len(val_dataset)}",
    sep="\n",
)

Number of training samples: 24000
Number of validation samples: 2684


In [None]:
# Options shared by both loaders.
# persistent_workers is only valid when worker processes exist: DataLoader raises
# ValueError for persistent_workers=True with num_workers=0, so the flag follows
# num_workers instead of being hard-coded. With the current num_workers value the
# behaviour is unchanged.
_loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **_loader_options)

In [None]:
# Model definition
class PneumoniaModel(torch.nn.Module):
    """ResNet-18 adapted to single-channel images with a one-logit output head.

    The backbone is loaded with ImageNet weights. Its first convolution is replaced
    by a 1-input-channel layer and its final fully connected layer by a single-output
    linear layer; both replacements are freshly initialised, not pretrained.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in newer torchvision releases in favour of
        # the `weights=` argument. Use the new API when the weights enum exists and
        # fall back to the legacy flag on older versions. IMAGENET1K_V1 is the
        # checkpoint that `pretrained=True` maps to.
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.model = torchvision.models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.model = torchvision.models.resnet18(pretrained=True)

        # Changed color channel to 1 (the stock layer expects 3-channel input).
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Read the head's input width from the backbone instead of hard-coding 512.
        self.model.fc = torch.nn.Linear(in_features=self.model.fc.in_features, out_features=1, bias=True)

    def forward(self, data):
        """Run the network on ``data`` and return its raw output, one value per sample.

        The output is an unnormalised logit (no sigmoid is applied here); its last
        dimension has size 1 because the head has a single output feature.
        """
        return self.model(data)

In [None]:
# Pick the compute device (GPU when CUDA is usable, CPU otherwise), then build the
# network and place its parameters on that device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = PneumoniaModel()
model = model.to(device)

In [None]:
# Binary cross-entropy on raw logits. pos_weight scales the loss of positive
# samples by 3 (presumably to offset class imbalance — confirm against the label
# distribution). It is created as a float tensor so the loss does not rely on
# implicit int -> float promotion when it is combined with the float logits.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0])).to(device)

# Adam over every model parameter with a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Profiler settings for performance tracking
# Only request CUDA activity when a GPU is actually present, so the notebook also
# runs cleanly on CPU-only machines.
_profiler_activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    _profiler_activities.append(torch.profiler.ProfilerActivity.CUDA)

profiler = torch.profiler.profile(
    activities=_profiler_activities,
    # One cycle = 1 idle step, 1 warm-up step, 3 recorded steps. repeat=1 limits
    # profiling to a single cycle; with the default (repeat=0) the cycle restarts
    # every 5 steps for the entire run, writing a new trace each time.
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
    # Each finished cycle is written to ./log in a format TensorBoard can load.
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
)

In [None]:
# Training loop
num_epochs = 35

# Entering the profiler as a context manager starts it before the first step and
# stops it when the block exits (also on an exception or manual interrupt), instead
# of only calling step() on a profiler that was never started.
with profiler:
    for epoch in range(num_epochs):
        model.train()
        start_time = time.time()  # Track epoch start time
        batch_start = start_time  # Start of the current batch (includes data loading)

        for batch_idx, (data, targets) in enumerate(train_loader):
            # Move data and targets to the training device.
            # BCEWithLogitsLoss needs float targets, so the integer class indices
            # coming from the dataset are converted here.
            data = data.to(device)
            targets = targets.to(device).float()

            # Zero gradients, perform a backward pass, and update the weights
            optimizer.zero_grad()
            # The model emits one logit per sample with a trailing dimension of
            # size 1; drop it so the logits have the same shape as the targets,
            # which the loss requires.
            outputs = model(data)[:, 0]
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            should_log = batch_idx % 100 == 0

            # Print loss for every 100th batch
            if should_log:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item()}")

            # Profiler step
            profiler.step()

            # Report the time spent on this batch alone (previously this measured
            # the time elapsed since the start of the epoch).
            # NOTE: GPU work is asynchronous; loss.item() above synchronises on the
            # logged batches, so the figure is only approximate.
            if should_log:
                batch_time = time.time() - batch_start
                print(f"Batch {batch_idx} processed in {batch_time:.4f} seconds")
            batch_start = time.time()

        print(f"Epoch {epoch+1} completed in {time.time() - start_time:.2f} seconds")