In [2]:
import numpy as np
import pandas as pd
from PIL import Image
import os

import torch
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50

In [26]:
# Create a new directory if it does not exist
def create_folder(folder):
    if not os.path.exists(f'{folder}'):
        os.makedirs(folder)
        print(f'New directory {folder} is created!')
    return folder


class_folder = 'groups/CS156b'
group_folder = f'{class_folder}/2024/Edgemax'
preprocess_folder = f'{group_folder}/preprocess'

pathologies = ['No Finding']
orientations = ['Frontal', 'Lateral']


train_file_path = f''
train_csv_path = f'train.csv'
train_data = [[[] for _ in range(len(orientations))] for _ in range(len(pathologies))]
train_labels = [[[] for _ in range(len(orientations))] for _ in range(len(pathologies))]
csv_train = pd.read_csv(train_csv_path, sep=',')[:10000]
transform = transforms.Compose([
    transforms.Resize(320),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


for index, row in csv_train.iterrows():
    if index % max(1, len(csv_train) // 1000) == 0:
        print(f'{round(index / csv_train.index[-1] * 100, 2)}% completed')
    image_path = f'{row['Path']}'
    if os.path.exists(image_path):
        image = Image.open(image_path)
        image = image.convert('RGB')
        image_tensor = transform(image)
        for i, pathology in enumerate(pathologies):
            label_tensor = torch.tensor([row[pathology]])
            if (not np.isnan(row[pathology])):
                j = (row['Frontal/Lateral'] != 'Frontal')
                train_data[i][j].append(image_tensor)
                train_labels[i][j].append(label_tensor)
                

for i, pathology in enumerate(pathologies):
    for j, orientation in enumerate(orientations):
        print(f'Train data for: {pathology.replace(' ', '_')} {orientation}')
        sub_folder =  f'{preprocess_folder}/{pathology.replace(' ', '_')}/{orientation}'
        create_folder(sub_folder)
        print(f'Saving {len(train_data[i][j])} training images.')
        torch.save(train_data[i][j], f'{sub_folder}/train_data_{pathology.replace(' ', '_')}_{orientation}.pt')
        print(f'Saving {len(train_labels[i][j])} training labels.')
        torch.save(train_labels[i][j], f'{sub_folder}/train_labels_{pathology.replace(' ', '_')}_{orientation}.pt')
print('Preprocessing completed.')

0.0% completed
0.1% completed
0.2% completed
0.3% completed
0.4% completed
0.5% completed
0.6% completed
0.7% completed
0.8% completed
0.9% completed
1.0% completed
1.1% completed
1.2% completed
1.3% completed
1.4% completed
1.5% completed
1.6% completed
1.7% completed
1.8% completed
1.9% completed
2.0% completed
2.1% completed
2.2% completed
2.3% completed
2.4% completed
2.5% completed
2.6% completed
2.7% completed
2.8% completed
2.9% completed
3.0% completed
3.1% completed
3.2% completed
3.3% completed
3.4% completed
3.5% completed
3.6% completed
3.7% completed
3.8% completed
3.9% completed
4.0% completed
4.1% completed
4.2% completed
4.3% completed
4.4% completed
4.5% completed
4.6% completed
4.7% completed
4.8% completed
4.9% completed
5.0% completed
5.1% completed
5.2% completed
5.3% completed
5.4% completed
5.5% completed
5.6% completed
5.7% completed
5.8% completed
5.9% completed
6.0% completed
6.1% completed
6.2% completed
6.3% completed
6.4% completed
6.5% completed
6.6% compl

In [33]:
class XRayDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        return self.features[index], self.labels[index]
    

pathology = 'No Finding'
orientation = 'Frontal'

class_folder = 'groups/CS156b'
preprocess_folder = f'{class_folder}/2024/Edgemax/preprocess'
sub_folder =  f'{preprocess_folder}/{pathology.replace(' ', '_')}/{orientation}'
train_data_path = f'{sub_folder}/train_data_{pathology.replace(' ', '_')}_{orientation}.pt'
train_labels_path = f'{sub_folder}/train_labels_{pathology.replace(' ', '_')}_{orientation}.pt'
model_path = f'{sub_folder}/model_{pathology.replace(' ', '_')}_{orientation}.pt'

train_proportion = 0.8
batch_size = 8
learning_rate = 0.001
momentum = 0.9
num_epochs = 10

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device type: {device}")


data = torch.load(train_data_path)
labels = torch.load(train_labels_path)
train_data = data[:int(train_proportion*len(data))]
train_labels = labels[:int(train_proportion*len(labels))]
val_data = data[int(train_proportion*len(data)):]
val_labels = labels[int(train_proportion*len(labels)):]

train_dataset = XRayDataset(train_data, train_labels)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = XRayDataset(val_data, val_labels)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)


model = resnet50(weights='ResNet50_Weights.DEFAULT')
model.fc = nn.Sequential(nn.Linear(2048, 256),
                        nn.ReLU(), 
                        nn.Linear(256, 1),
                        nn.Tanh())
model = nn.DataParallel(model)
model.to(device)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)


for epoch in range(num_epochs):
    idx = 0
    for inputs, labels in train_dataloader:
        # labels = labels.unsqueeze(1)
        inputs, labels = inputs.to(device), labels.to(device)
        if idx % max(1, len(train_dataloader) // 50) == 0:
            print(".", end ="")
        idx += 1
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}')

print(f'Finished Training')
torch.save(model.state_dict(), model_path)

Device type: cpu
.......Epoch 1/10, Loss: 0.6018
.......Epoch 2/10, Loss: 0.2593
.......Epoch 3/10, Loss: 0.4292
.......Epoch 4/10, Loss: 0.3933
.

KeyboardInterrupt: 