### IMPORTS

In [None]:
# IMPORTS=
import pandas as pd

# PYTORCH=
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader

# MY DATASET=
from src import ChestXrayDataset as CXD

### DATA PREPROCESSING

In [None]:
# Root folder of the 256x256 dataset. Paths are built by plain string
# concatenation, so the trailing slash matters.
data_dir = "src/data/input/256x256/"
# Raw annotation table; presumably one row per annotation, so an image_id can
# appear on several rows.
train_df = pd.read_csv(f"{data_dir}train.csv")

In [None]:
# Majority vote per image: for each image_id keep its most frequent class_id
# (on a tie, whichever label value_counts() lists first).
df_ = (
    train_df.groupby("image_id")["class_id"]
    .apply(lambda labels: labels.value_counts().index[0])
    .reset_index()
)
# Collapse the remaining columns to one row per image. class_id is dropped
# first so the merge below does not produce duplicate class_id columns.
# NOTE(review): GroupBy.first() takes the first non-null value *per column*,
# so the values kept for an image may come from different original rows, and
# any other label-like column keeps its first value rather than the majority
# vote -- confirm this is acceptable downstream.
train_df = (
    train_df.drop(columns="class_id")
    .groupby("image_id")
    .first()
    .reset_index()
)
# Attach the voted class_id back onto the per-image rows and persist them.
# NOTE(review): to_csv keeps its default index=True, so the file gains an
# unnamed leading index column; presumably the dataset reader expects that --
# confirm before changing.
merge_df = train_df.merge(df_, on="image_id")
merge_df.to_csv(data_dir + "train_clean.csv")

### PYTORCH CHEST-XRAY DATASET

In [None]:
# Reload the cleaned, one-row-per-image annotation file.
data_dir = "src/data/input/256x256/"
train_df = pd.read_csv(f"{data_dir}train_clean.csv")

# Image pipeline applied by the dataset: convert to a tensor, then normalize
# each of three channels with (x - 0.5) / 0.5.
# NOTE(review): assumes the dataset yields 3-channel images -- TODO confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
# Wrap the cleaned CSV in the project's Dataset class; presumably it resolves
# csv_file relative to data_dir -- confirm in src/ChestXrayDataset.
dataset = CXD.ChestXrayDataset(
    data_dir=data_dir,
    csv_file="train_clean.csv",
    transform=transform,
)

### TRAIN/VALIDATION SPLIT

In [None]:
# Split the dataset 80/20 into train and validation subsets. A fixed-seed
# generator makes the split reproducible across kernel restarts, so metrics
# from different runs are computed on the same validation images.
split_seed = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create the dataloaders. Only training data needs shuffling; validation
# metrics are order-independent, so a fixed order keeps evaluation stable.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print("Number of train images: ", len(train_dataset))
print("Number of validation images: ", len(val_dataset))
print("Number of batches: ", len(train_loader))
# Sanity check: print the shapes of the first training batch (if any).
for images, labels in train_loader:
    print("Images shape: ", images.shape)
    print("Labels shape: ", labels.shape)
    break

### MODEL TRAINING

In [None]:
# Pick the best available device: CUDA, then Apple MPS, then CPU.
# torch.backends.mps.is_available() checks that MPS is usable at runtime;
# the deprecated torch.has_mps only reported whether PyTorch was built with
# MPS support, which could select a device that cannot actually be used.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# ImageNet-pretrained ResNet-18 with its final fully connected layer replaced
# by a num_classes-way linear classifier. IMAGENET1K_V1 are the weights the
# deprecated pretrained=True flag loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_classes = 15
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
model = model.to(device)

# Cross-entropy over the class logits; Adam with a fixed learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0
    train_correct = 0
    model.train()  # Training mode (affects layers such as batch norm)
    print("Epoch: ", epoch)
    for i, (images, labels) in enumerate(train_loader):
        if i % 10 == 0:
            print("Batch: "+str(i)+" began.")
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()   # Compute gradients
        optimizer.step()  # Update weights from the gradients
        # criterion uses the default mean reduction, so loss * batch_size is
        # the batch's summed loss. Add it exactly once per batch so the epoch
        # average below is a true per-sample mean.
        train_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        train_correct += (preds == labels).sum().item()
    train_acc = train_correct / len(train_dataset)
    train_loss = train_loss / len(train_dataset)
    print("Epoch: {}/{}...".format(epoch + 1, num_epochs),
          "Training Loss: {:.4f}...".format(train_loss),
          "Training Accuracy: {:.4f}".format(train_acc))

In [None]:
# Evaluate the model on the validation set.
model.eval()  # Inference mode (affects layers such as batch norm)
val_loss = 0.0
val_correct = 0
# Evaluation needs no gradients: disabling autograd avoids building the
# computation graph for every batch and saves memory.
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Mean-reduced loss times batch size = summed loss for this batch.
        val_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        val_correct += (preds == labels).sum().item()
val_acc = val_correct / len(val_dataset)
val_loss = val_loss / len(val_dataset)
print("Validation Loss: {:.4f}...".format(val_loss),
      "Validation Accuracy: {:.4f}".format(val_acc))

In [None]:
# Save the model weights (state_dict only); uncomment the line below to write the file.
# torch.save(model.state_dict(), "src/data/output/model.pth")