### IMPORTS

In [None]:
# IMPORTS=
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import cv2
import pydicom as dicom
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader

# MY DATASET=
from src import ChestXrayDataset as CXD
from src import ChestXrayDatasetV2 as CXD2
from src.Preprocessing import preprocess_data
from src.Preprocessing import class_ids_and_names

### DATA PREPROCESSING

In [None]:
# SETTINGS =
data_dir = "src/data/input/256x256/"                            # MAIN DIRECTORY CONTAINING THE DATA
train_df = pd.read_csv(data_dir + "train.csv")                  # TRAINING DATA
train_df_sizes = pd.read_csv(data_dir + "train_meta.csv")       # TRAINING DATA SIZES

# ADV. SETTINGS =
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is used instead of the deprecated `torch.has_mps`,
# which only tells whether PyTorch was *built* with MPS support, not whether it is usable here.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# PREPROCESSING:
# The raw frame is re-read from disk above, so re-running this cell always starts from the same input.
train_df = preprocess_data(data_dir, train_df, train_df_sizes)  # DATA PREPROCESSING (+ SAVE -> train_clean.csv)
class_ids, class_names = class_ids_and_names(train_df)

### FIRST APPROACH: RESNET18 FULL IMPLEMENTATION

In [None]:
# TRANSFORMATION PIPELINE:
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to tensor
    # Per-channel (x - 0.5) / 0.5; the plotting cell at the end undoes this with (img + 1) / 2.
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# DATASET:
dataset = CXD.ChestXrayDataset(csv_file="train_clean.csv", data_dir=data_dir, transform=transform)

# TRAIN/VALIDATION SPLIT:
ratio = 0.8     # fraction passed to dataset.split; presumably the training share — verify in ChestXrayDataset
train_dataset, val_dataset = dataset.split(ratio)

# DATALOADERS:
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation needs no shuffling: the metrics are order-independent, and a fixed order
# keeps evaluation runs comparable batch by batch.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Sanity check: report the dataset sizes and the tensor shapes of one training batch.
for caption, sized in (
    ("Number of train images: ", train_dataset),
    ("Number of validation images: ", val_dataset),
    ("Number of batches: ", train_loader),
):
    print(caption, len(sized))

# Pull only the first batch; nothing is printed if the loader is empty.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print("Images shape: ", images.shape)
    print("Labels shape: ", labels.shape)

In [None]:
# Define the model:
# ImageNet-pretrained ResNet-18. The `weights=` enum replaces the deprecated `pretrained=True`
# flag (torchvision >= 0.13); IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_classes = 15    # NOTE(review): hard-coded; presumably should equal len(class_ids) — confirm
num_ftrs = model.fc.in_features
# Swap the ImageNet classification head for one with `num_classes` outputs.
model.fc = nn.Linear(num_ftrs, num_classes)
model = model.to(device)

# Define the loss function and the optimizer:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train the model:
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0
    train_correct = 0
    model.train()   # Set the model to training mode
    print("Epoch: ", epoch)
    for i, (images, labels) in enumerate(train_loader):
        if i % 10 == 0:
            print("Batch: "+str(i)+" began.")
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()     # Compute the gradients
        optimizer.step()    # Update the weights
        # `loss` is the batch mean, so weight it by the batch size before summing.
        # It must be accumulated exactly once per batch, otherwise the epoch loss is doubled.
        train_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        train_correct += (preds == labels).sum().item()
    # Per-sample averages over the whole epoch:
    train_acc = train_correct / len(train_dataset)
    train_loss = train_loss / len(train_dataset)
    print("Epoch: {}/{}...".format(epoch + 1, num_epochs),
          "Training Loss: {:.4f}...".format(train_loss),
          "Training Accuracy: {:.4f}".format(train_acc))

In [None]:
# Test the model on the validation set:
model.eval()    # Inference mode for dropout / batch-norm layers
val_loss = 0.0
val_correct = 0
# No gradients are needed for evaluation; disabling autograd saves memory and time.
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        # `loss` is the batch mean, so weight it by the batch size before summing.
        val_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        val_correct += (preds == labels).sum().item()
# Per-sample averages over the whole validation set:
val_acc = val_correct / len(val_dataset)
val_loss = val_loss / len(val_dataset)
print("Validation Loss: {:.4f}...".format(val_loss),
      "Validation Accuracy: {:.4f}".format(val_acc))

In [None]:
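# Save the model (disabled; uncomment to write the trained weights to disk).
# Note: the loading cell reads "resnet18_e10.pth", not "model.pth", so rename the file accordingly.
# torch.save(model.state_dict(), "src/data/output/model.pth")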

In [None]:
# Load the model:
# `map_location=device` remaps the stored tensors onto the current device, so a checkpoint
# saved on a CUDA machine also loads on CPU/MPS-only machines.
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load("src/data/output/resnet18_e10.pth", map_location=device)
model.load_state_dict(state_dict)
# Then, re-test the model on the validation set.

### NEW MODEL: YOLOv5

In [None]:
# https://www.kaggle.com/ultralytics/yolov5
# !git clone https://github.com/ultralytics/yolov5  # clone repo
# %cd yolov5
# !pip install -r requirements.txt  # install dependencies
# cmd = "!python {yolo_dir}train.py --img 256 --batch 32 --epochs 2 --data {yaml_path} --weights {model} --cache"

In [None]:
# Hold out (1 - ratio) of the *images* for validation.
# Sampling image ids rather than rows guarantees that no image ends up on both sides of the
# split, even if an image spans several rows; when every row has a distinct image_id the
# selected validation rows are the same as with a plain row-wise sample.
# NOTE: this cell overwrites train_df, so re-running it without re-running the preprocessing
# cell shrinks the training set further.
ratio = 0.8
unique_image_ids = train_df["image_id"].drop_duplicates()
val_image_ids = unique_image_ids.sample(frac=1-ratio, random_state=42)
in_val = train_df["image_id"].isin(val_image_ids)
val_df = train_df[in_val]
train_df = train_df[~in_val]

In [None]:
# YOLO STUFF:
yolostuff_dir = "src/yolostuff/"
# Root of the YOLOv5 checkout; the `!python {yolo_dir}train.py` / `detect.py` cells
# interpolate this name, so it has to exist before they run.
yolo_dir = yolostuff_dir + "yolov5/"
yaml_path = yolostuff_dir + "datasets/vinbigdata/vinbigdata.yaml"
# NOTE: this rebinds `model` (the ResNet above) to the YOLO weights path, because the
# training cell interpolates `{model}`.
model = yolo_dir + "models/yolov5s.pt"


def _write_image_list(image_ids, txt_path):
    """Write one "./images/<image_id>.png" line per entry of `image_ids` to `txt_path`.

    The file is overwritten if it already exists.
    """
    # `with` closes the file even if writing fails part-way through.
    with open(txt_path, "w") as txt_file:
        txt_file.writelines("./images/" + image_id + ".png\n" for image_id in image_ids)


# Save the names of all the training images in a .txt:
# NOTE(review): one line is written per row, so an image_id that occurs on several rows is listed several times — confirm that is intended.
_write_image_list(train_df["image_id"], yolostuff_dir + "datasets/vinbigdata/train.txt")

# Save the names of all the validation images in a .txt:
_write_image_list(val_df["image_id"], yolostuff_dir + "datasets/vinbigdata/val.txt")

# Test images: only the submission template is loaded for now.
test_df = pd.read_csv(data_dir + "sample_submission.csv")
# TODO: write the test image list as well.

In [None]:
# Launch YOLOv5 training through its command-line script.
# IPython substitutes {yolo_dir}, {yaml_path} and {model} with the kernel variables of the
# same name, so all three must be defined before this cell is run.
!python {yolo_dir}train.py --img 256 --batch 32 --epochs 1 --data {yaml_path} --weights {model} --cache

In [None]:
# Run YOLOv5 inference on the test images with the weights of a finished training run.
# {yolo_dir} is substituted from the kernel variable of the same name and must be defined.
# NOTE(review): the weights path points at one specific run folder (exp19) — presumably it
# has to be updated after every new training run; confirm.
!python {yolo_dir}detect.py --weights 'src/yolostuff/yolov5/runs/train/exp19/weights/best.pt' --img 256 --conf 0.15 --iou 0.5 --source 'src/data/input/256x256/test' --exist-ok

## VISUALIZATION

In [None]:
# Debug: grab a single batch from the training loader; the next cell plots one of its
# images together with its label and bounding boxes.
train_batches = iter(train_loader)
batch1 = next(train_batches)

In [None]:
index = 10  # Position of the sample inside the batch
# Move the channel axis last (C, H, W -> H, W, C), the layout imshow expects:
img = batch1[0][index].permute(1, 2, 0)
# After Normalize(0.5, 0.5) the values lie in [-1, 1]; rescale them to [0, 1],
# the range imshow expects for float images:
img = (img + 1) / 2
fig, ax = plt.subplots()
ax.imshow(img)
label = batch1[1][index].item()
# NOTE(review): this needs a third element (the boxes) in every batch, while the training and
# validation loops unpack only (images, labels) — confirm which dataset class feeds train_loader.
bbox = batch1[2][index]
print(bbox)
print(label)
# Debug: Plot the bounding boxes.
# `bbox` is read as a flat sequence of consecutive (xmin, ymin, width, height) groups;
# the range stops early so an incomplete trailing group is skipped instead of raising IndexError.
for i in range(0, len(bbox) - 3, 4):
    xmin = bbox[i].item()
    ymin = bbox[i+1].item()
    width = bbox[i+2].item()
    height = bbox[i+3].item()
    rect = plt.Rectangle((xmin, ymin), width, height, fill=False, color='red')
    ax.add_patch(rect)
# Render once, after all rectangles were added; calling show() inside the loop
# would display the figure with only the first box drawn on the image.
plt.show()