In [None]:
### ORIGINAL ####
# -------
# Install
# -------
# !pip install torch torchvision torchmetrics

# -------------------------
# Import required libraries
# -------------------------
# Data loading
import random
import os
import shutil
import numpy as np
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
# Train model
import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
import torch.optim as optim
# Evaluate model
from torchmetrics import Accuracy, F1Score
# Check for GPU availability
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print("Using device:", device)
# Check for MPS availability
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print("Using device:", device)
Using device: mps


#---------------------------------------------------------------------------------
# Move 50 random images per class from the training set to create a validation set
#---------------------------------------------------------------------------------
def move_files(src_class_dir, dest_class_dir, n=50):
    if not os.path.exists(dest_class_dir):
        os.makedirs(dest_class_dir)
    files = os.listdir(src_class_dir)
    random_files = random.sample(files, n)
    for f in random_files:
        shutil.move(os.path.join(src_class_dir, f), os.path.join(dest_class_dir, f))
if not os.path.exists('data/chestxrays/val'):
    move_files('data/chestxrays/train/NORMAL', 'data/chestxrays/val/NORMAL')
    move_files('data/chestxrays/train/PNEUMONIA', 'data/chestxrays/val/PNEUMONIA')


#------------------------------------
# Transformations and create datasets
#------------------------------------
# Define the transformations to apply to the images for use with ResNet-50.
# The images need to be normalized to the same domain as the original training data of ResNet-50 network.
# Normalize the X-rays using transforms. A Normalize function that takes as input the means and
# standard deviations of the three color channels, (R,G,B), from the original ResNet-50 training dataset.
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]
# Training transforms: Add horizontal flip for augmentation
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std)
])
# Validation and test transforms: no augmentation, just normalization
val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std)
])
# Create datasets
train_dataset = ImageFolder('data/chestxrays/train', transform=train_transform)
val_dataset = ImageFolder('data/chestxrays/val', transform=val_test_transform)
test_dataset = ImageFolder('data/chestxrays/test', transform=val_test_transform)
# Data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)
print("Training set size:", len(train_dataset))
print("Validation set size:", len(val_dataset))
print("Test set size:", len(test_dataset))


#----------------------
# Instantiate the model
#----------------------
# Load the pre-trained ResNet-50 model with new weights with accuracy 80.858%
resnet50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
#-----------------
# Modify the model
#-----------------
# Freeze the parameters of the model
for param in resnet50.parameters():
    param.requires_grad = False
# Modify the final layer for binary classification
resnet50.fc = nn.Linear(resnet50.fc.in_features, 1)
# Set the model to ResNet-50
model = resnet50
# Move the model to the selected device (GPU, MPS, or CPU)
model.to(device)
#-------------------------
# Define the training loop
#-------------------------
# Training function with validation and early stopping
def train_with_validation(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=5):
    best_val_loss = float('inf')
    epochs_no_improve = 0
for epoch in range(num_epochs):
        # Training Phase
        model.train()
        running_loss = 0.0
        running_accuracy = 0
for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            labels = labels.float().unsqueeze(1)
            
            # Use mixed precision training for forward pass and loss computation
            with torch.amp.autocast(device_type=device.type):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
preds = torch.sigmoid(outputs) > 0.5
            running_loss += loss.item() * inputs.size(0)
            running_accuracy += torch.sum(preds == labels.data)
train_loss = running_loss / len(train_loader.dataset)
        # train_acc = running_accuracy.double() / len(train_loader.dataset) # For GPU/CPU
        train_acc = running_accuracy.float() / len(train_loader.dataset) # For MPS
# Validation Phase
        model.eval()
        val_loss = 0.0
        val_accuracy = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                labels = labels.float().unsqueeze(1)
                
                # Use mixed precision training for forward pass and loss computation
                with torch.amp.autocast(device_type=device.type):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                
                preds = torch.sigmoid(outputs) > 0.5
                val_loss += loss.item() * inputs.size(0)
                val_accuracy += torch.sum(preds == labels.data)
val_loss = val_loss / len(val_loader.dataset)
        # val_acc = val_accuracy.double() / len(val_loader.dataset)  # For GPU/CPU
        val_acc = val_accuracy.float() / len(val_loader.dataset) # For MPS
print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
# Step the scheduler with the validation loss
        scheduler.step(val_loss)
# Early Stopping Check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save model
            torch.save(model.state_dict(), 'model.pth')
        else:
            epochs_no_improve += 1
if epochs_no_improve >= patience:
            print("Early stopping triggered")
            break
print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))


#--------------------
# Fine-tune the model
#--------------------
# Set up loss and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
# Use ReduceLROnPlateau scheduler to reduce LR if validation loss stagnates
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
# Decay lr by 10% every epoch (alternative scheduler)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
# Train the model with early stopping and validation
train_with_validation(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=5)



#-------------------
# Evaluate the model
#-------------------
# Set model to evaluation mode
model.eval()
# Load the best model weights
model.load_state_dict(torch.load('model.pth'))
# Initialize metrics for accuracy and F1 score
accuracy_metric = Accuracy(task="binary")
f1_metric = F1Score(task="binary")
# Create lists store all predictions and labels
all_preds = []
all_labels = []
# Disable gradient calculation for evaluation
with torch.no_grad():
  for inputs, labels in test_loader:
    # Move inputs and labels to the device
    inputs, labels = inputs.to(device), labels.to(device)
    
    # Forward pass
    outputs = model(inputs)
    preds = torch.sigmoid(outputs).round() # Round to 0 or 1
    
    # Extend the lists with predictions and labels
    all_preds.extend(preds.cpu().tolist())
    all_labels.extend(labels.unsqueeze(1).cpu().tolist())
# Convert lists to tensors
all_preds = torch.tensor(all_preds)
all_labels = torch.tensor(all_labels)
# Compute metrics for the entire test set
test_acc = accuracy_metric(all_preds, all_labels).item()
test_f1 = f1_metric(all_preds, all_labels).item()
print(f"Test accuracy: {test_acc:.3f}")
print(f"Test F1-score: {test_f1:.3f}")
# Save the model
torch.save(model.state_dict(), 'model_final.pth')
# Save the model with TorchScript for deployment
model_scripted = torch.jit.script(model)  # Scripting the model
torch.jit.save(model_scripted, 'model_scripted.pt')
# Save the model with TorchScript for deployment
model_traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to(device))  # Tracing the model
torch.jit.save(model_traced, 'model_traced.pt')
# Load the model for inference
# Load the model with TorchScript for deployment
# model_loaded = torch.jit.load('model_scripted.pt')
# model_loaded = torch.jit.load('model_traced.pt')
# model_loaded.eval()
# # Example inference with the loaded model


In [None]:
#### COPPILOT ###
#!/usr/bin/env python3
"""
Improved training script for binary classification on chest X-ray images.
This script:
  - Creates a validation set by moving random files
  - Applies proper transforms (and optional augmentation)
  - Loads a pre-trained ResNet-50 and fine-tunes its final layer
  - Trains using early stopping with a validation loop and learning rate scheduler
  - Evaluates the model on a test set with accuracy and F1-score metrics
  - Saves the trained model in both state dict and TorchScript formats
"""
import os
import random
import shutil
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.models import resnet50, ResNet50_Weights
from torchmetrics import Accuracy, F1Score
# from tqdm import tqdm  # Optional: progress bar for loops
from tqdm.notebook import tqdm  # Notebook-friendly progress bar

# ---------------------------
# Set random seeds for reproducibility
# ---------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# ---------------------------
# Device selection
# ---------------------------
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


# ---------------------------
# Utility Function: Move files for validation set creation
# ---------------------------
def move_files(src_class_dir: str, dest_class_dir: str, n: int = 50) -> None:
    """
    Moves n random files from the source directory to the destination directory.
    """
    if not os.path.exists(dest_class_dir):
        os.makedirs(dest_class_dir)
    files = os.listdir(src_class_dir)
    if len(files) < n:
        raise ValueError(f"Not enough files in {src_class_dir} to move {n} files.")
    random_files = random.sample(files, n)
    for f in random_files:
        shutil.move(os.path.join(src_class_dir, f), os.path.join(dest_class_dir, f))
# ---------------------------
# Data preparation: Create validation set if needed
# ---------------------------
def prepare_validation_set():
    if not os.path.exists('data/chestxrays/val'):
        move_files('data/chestxrays/train/NORMAL', 'data/chestxrays/val/NORMAL', n=50)
        move_files('data/chestxrays/train/PNEUMONIA', 'data/chestxrays/val/PNEUMONIA', n=50)


# ---------------------------
# Data transforms and datasets
# ---------------------------
def prepare_datasets():
    """
    Defines image transformations and creates training, validation, and test datasets.
    """
    # ResNet-50 expects images normalized using its original channel means and stds.
    transform_mean = [0.485, 0.456, 0.406]
    transform_std = [0.229, 0.224, 0.225]

    # Define augmentation for training and simple normalization for validation and testing
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std)
    ])
    val_test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std)
    ])

    train_dataset = datasets.ImageFolder('data/chestxrays/train', transform=train_transform)
    val_dataset = datasets.ImageFolder('data/chestxrays/val', transform=val_test_transform)
    test_dataset = datasets.ImageFolder('data/chestxrays/test', transform=val_test_transform)

    print("Training set size:", len(train_dataset))
    print("Validation set size:", len(val_dataset))
    print("Test set size:", len(test_dataset))

    return train_dataset, val_dataset, test_dataset


# ---------------------------
# Build DataLoaders
# ---------------------------
def prepare_dataloaders(train_dataset, val_dataset, test_dataset, batch_size: int = 32):
    """
    Loads a pre-trained ResNet-50 model, freezes its parameters,
    and replaces the final fully connected layer for binary classification.
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader


# ---------------------------
# Model setup: Load and modify ResNet-50 for binary classification
# ---------------------------
def build_model() -> torch.nn.Module:
    model_resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    # Freeze all parameters
    for param in model_resnet.parameters():
        param.requires_grad = False

    # Replace the final classification layer with one that outputs a single logit (for binary classification)
    num_features = model_resnet.fc.in_features
    model_resnet.fc = nn.Linear(num_features, 1)
    model_resnet = model_resnet.to(device)
    return model_resnet


# ---------------------------
# Training loop with validation and early stopping
# ---------------------------
def train_with_validation(model: torch.nn.Module,
                          train_loader: DataLoader,
                          val_loader: DataLoader,
                          criterion,
                          optimizer,
                          scheduler,
                          device: torch.device,
                          num_epochs: int = 50,
                          patience: int = 5) -> None:
            patience: int = 5) -> None:
    """
    Trains the model with a training loop that includes validation after each epoch. 
    Saves the model if validation loss improves.
    Early stops training if the validation loss does not improve for 'patience' epochs.
    """
    best_val_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0

        # Training phase
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Adjust labels shape: add extra dimension and cast to float
            labels = labels.float().unsqueeze(1)

            with torch.amp.autocast(device_type=device.type):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            preds = (torch.sigmoid(outputs) > 0.5).float()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.float() / len(train_loader.dataset)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation", leave=False):
                inputs, labels = inputs.to(device), labels.to(device)
                labels = labels.float().unsqueeze(1)
                with torch.amp.autocast(device_type=device.type):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                preds = (torch.sigmoid(outputs) > 0.5).float()
                val_loss += loss.item() * inputs.size(0)
                val_corrects += torch.sum(preds == labels.data)

        val_loss = val_loss / len(val_loader.dataset)
        val_acc = val_corrects.float() / len(val_loader.dataset)

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Step LR scheduler based on validation loss
        scheduler.step(val_loss)

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Save best model
            torch.save(model.state_dict(), 'model.pth')
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping triggered")
            break

    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))


# ---------------------------
# Model evaluation on test set
# ---------------------------
def evaluate_model(model: torch.nn.Module, test_loader: DataLoader):
    """
    Evaluates the model on the test set and prints Accuracy and F1-score.
    """
    model.eval()
    accuracy_metric = Accuracy(task="binary")
    f1_metric = F1Score(task="binary")
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Testing"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = torch.sigmoid(outputs).round()  # Binary prediction: round to 0 or 1
            all_preds.extend(preds.cpu().numpy().tolist())
            # Adjust labels to a comparable shape
            all_labels.extend(labels.unsqueeze(1).cpu().numpy().tolist())

    # Convert lists to tensors to compute final metrics
    all_preds_tensor = torch.tensor(all_preds)
    all_labels_tensor = torch.tensor(all_labels)

    test_acc = accuracy_metric(all_preds_tensor, all_labels_tensor).item()
    test_f1 = f1_metric(all_preds_tensor, all_labels_tensor).item()

    print(f"Test accuracy: {test_acc:.3f}")
    print(f"Test F1-score: {test_f1:.3f}")
    return test_acc, test_f1


# ---------------------------
# Save model in different formats for deployment
# ---------------------------
def save_model_formats(model: torch.nn.Module, device: torch.device):
    """
    Saves the final model state as a state dict, as well as in TorchScript's scripted and traced formats.
    """
    # Save the final state dict
    torch.save(model.state_dict(), 'model_final.pth')
    # Save a scripted model
    model_scripted = torch.jit.script(model)
    torch.jit.save(model_scripted, 'model_scripted.pt')
    # Save a traced model using a dummy input (assuming images are 224x224 with 3 channels)
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    model_traced = torch.jit.trace(model, dummy_input)
    torch.jit.save(model_traced, 'model_traced.pt')
    print("Model saved in state dict, scripted, and traced formats.")


# # ---------------------------
# # Main routine: Prepare data, build model, train, evaluate, and save.
# # ---------------------------
# def main():
#     prepare_validation_set()
#     train_dataset, val_dataset, test_dataset = prepare_datasets()
#     train_loader, val_loader, test_loader = prepare_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=32)
#     model = build_model()

#     # Set up loss, optimizer, and LR scheduler
#     criterion = nn.BCEWithLogitsLoss()
#     optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

#     train_with_validation(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=5)

#     # Load the best model based on validation loss and evaluate on the test set
#     model.load_state_dict(torch.load('model.pth'))
#     evaluate_model(model, test_loader)

#     # Save model formats for deployment
#     save_model_formats(model, device)


if __name__ == "__main__":
    main() 


### Main Execution Pipeline ###

# Prepare the validation set if needed
prepare_validation_set()

# Create datasets and data loaders
train_dataset, val_dataset, test_dataset = prepare_datasets()
train_loader, val_loader, test_loader = prepare_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=32)
# Build the model for binary classification
model = build_model()

# Set up loss, optimizer, and learning rate scheduler
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

# Train the model with validation and early stopping
train_with_validation(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                      num_epochs=50, patience=5)

# Load the best model based on validation loss
model.load_state_dict(torch.load('model.pth'))
# Evaluate the model on the test set
evaluate_model(model, test_loader)

# Save the model in various formats for later deployment
save_model_formats(model, device)

In [None]:
#---------------------------------------------------

Pneumonia is one of the leading respiratory illnesses worldwide, and its timely and accurate diagnosis is crucial for effective treatment. Manually reviewing chest X-rays plays a critical role in this process, but AI can significantly expedite and enhance assessments.

In this project, I explored the ability of a deep learning model to distinguish pneumonia cases from normal lung X-ray images. I fine-tuned a pre-trained ResNet-18 convolutional neural network to classify X-rays into two categories: normal lungs and those affected by pneumonia. Leveraging the pre-trained weights of ResNet-18 allowed me to create an accurate classifier efficiently, reducing the resources and time needed for training.

## The Data

<img src="x-rays_sample.png" align="center"/>
&nbsp

The dataset consisted of 300 chest X-rays for training and 100 for testing, evenly divided between NORMAL and PNEUMONIA categories. The images had been preprocessed and organized into train and test folders, with data loaders ready for use with PyTorch. This project highlights my ability to implement advanced deep learning techniques in a healthcare context, showcasing how AI improves diagnostic processes.ses.

In [65]:
# # Make sure to run this cell to use torchmetrics.
!pip install torch torchvision torchmetrics

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [66]:
# Import required libraries
# Data loading
import random
import numpy as np
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Train model
import torch
from torchvision import models
import torch.nn as nn
import torch.optim as optim

# Evaluate model
from torchmetrics import Accuracy, F1Score

# Set random seeds for reproducibility
torch.manual_seed(101010)
np.random.seed(101010)
random.seed(101010)

In [67]:
import os
import zipfile

# Unzip the data folder
if not os.path.exists('data/chestxrays'):
    with zipfile.ZipFile('data/chestxrays.zip', 'r') as zip_ref:
        zip_ref.extractall('data')

In [68]:
# Normalize the X-rays using transforms.
# standard deviations of the three color channels, (R,G,B), from the original ResNet-18 training dataset.
transform_mean = [0.485, 0.456, 0.406]
transform_std =[0.229, 0.224, 0.225]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize(mean=transform_mean, std=transform_std)])

# Apply the image transforms
train_dataset = ImageFolder('data/chestxrays/train', transform=transform)
test_dataset = ImageFolder('data/chestxrays/test', transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset) // 2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

In [69]:
# Instantiate the model
# Load the pre-trained ResNet-18 model
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

In [70]:
# Modify the model
# Freeze the parameters of the model
for param in resnet18.parameters():
    param.requires_grad = False

# Modify the final layer for binary classification
resnet18.fc = nn.Linear(resnet18.fc.in_features, 1)

In [71]:
# Define the training loop
# Model training/fine-tuning loop
def train(model, train_loader, criterion, optimizer, num_epochs):
    
    # Train the model for the specified number of epochs
    for epoch in range(num_epochs):
        # Set the model to train mode
        model.train()

        # Initialize the running loss and accuracy
        running_loss = 0.0
        running_accuracy = 0

        # Iterate over the batches of the train loader
        for inputs, labels in train_loader:

            # Zero the optimizer gradients
            optimizer.zero_grad()
            
            # Ensure labels have the same dimensions as outputs
            labels = labels.float().unsqueeze(1)

            # Forward pass
            outputs = model(inputs)
            preds = torch.sigmoid(outputs) > 0.5 # Binary classification
            loss = criterion(outputs, labels)

            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()

            # Update the running loss and accuracy
            running_loss += loss.item() * inputs.size(0)
            running_accuracy += torch.sum(preds == labels.data)

        # Calculate the train loss and accuracy for the current epoch
        train_loss = running_loss / len(train_dataset)
        train_acc = running_accuracy.double() / len(train_dataset)

        # Print the epoch results
        print('Epoch [{}/{}], train loss: {:.4f}, train acc: {:.4f}'
              .format(epoch+1, num_epochs, train_loss, train_acc))


In [72]:
#Fine-tune the model       
# Set the model to ResNet-18
model = resnet18

# Fine-tune the ResNet-18 model for 3 epochs using the train_loader
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
train(model, train_loader, criterion, optimizer, num_epochs=3)

Epoch [1/3], train loss: 1.3915, train acc: 0.4567
Epoch [2/3], train loss: 0.8973, train acc: 0.4633
Epoch [3/3], train loss: 0.9199, train acc: 0.5033


In [73]:
# Evaluate the model

# Set model to evaluation mode
model = resnet18
model.eval()

# Initialize metrics for accuracy and F1 score
accuracy_metric = Accuracy(task="binary")
f1_metric = F1Score(task="binary")

# Create lists store all predictions and labels
all_preds = []
all_labels = []

# Disable gradient calculation for evaluation
with torch.no_grad():
  for inputs, labels in test_loader:
    # Forward pass
    outputs = model(inputs)
    preds = torch.sigmoid(outputs).round()  # Round to 0 or 1

    # Extend the lists with predictions and labels
    all_preds.extend(preds.tolist())
    all_labels.extend(labels.unsqueeze(1).tolist())

  # Convert lists back to tensors
  all_preds = torch.tensor(all_preds)
  all_labels = torch.tensor(all_labels)

  # Calculate accuracy and F1 score
  test_accuracy = accuracy_metric(all_preds, all_labels).item()
  test_f1_score = f1_metric(all_preds, all_labels).item()
  print(f"\nTest accuracy: {test_acc:.3f}\nTest F1-score: {test_f1:.3f}")


Test accuracy: 0.580
Test F1-score: 0.704
