In [1]:
import torch 

# Report whether PyTorch can see a CUDA device; later cells move the model and batches to 'cuda'.
torch.cuda.is_available()

True

In [2]:
import pandas as pd

# Load the labelled table; later cells read its 'ViewPosition' and 'Image_Path' columns.
dataset = pd.read_csv('data/labeled_data.csv')

In [3]:
# View positions kept for training; rows with any other value are dropped.
allowed_values = [
    'AP',
    'PA',
    'LATERAL',
    'LL',
]

# Boolean-mask row selection: keep only rows whose ViewPosition is in the allow-list.
dataset = dataset.loc[dataset['ViewPosition'].isin(allowed_values)]

In [4]:
# Number of rows remaining in `dataset` at this point in the notebook.
len(dataset)

361316

In [5]:
from sklearn.model_selection import train_test_split

# Stage 1: hold out 2.5% of the rows; everything else becomes the training split.
train_dataset, test_dataset = train_test_split(
    dataset,
    test_size=.025,
    shuffle=True,
    random_state=42,
)

# Stage 2: cut the held-out rows in half to obtain the test and validation splits.
test_dataset, validation_dataset = train_test_split(
    test_dataset,
    test_size=.5,
    shuffle=True,
    random_state=42,
)

# Report the size of each split.
print(f'Len Train Dataset: {len(train_dataset)}')
print(f'Len Validation Dataset: {len(validation_dataset)}')
print(f'Len Test Dataset: {len(test_dataset)}')

Len Train Dataset: 352283
Len Validation Dataset: 4517
Len Test Dataset: 4516


In [6]:
from torch.utils.data import Dataset
from PIL import Image

class MedicalDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a table of image paths.

    Args:
        csv_dataset: DataFrame with an 'Image_Path' column (file to open) and a
            'ViewPosition' column whose values are keys of ``label_map``.
            The frame is not modified; an encoded copy is stored in ``self.data``.
        transform: Optional callable applied to the grayscale PIL image. When
            omitted, the PIL image itself is returned from ``__getitem__``.

    Raises:
        ValueError: If 'ViewPosition' holds a value that is not in ``label_map``.
    """

    def __init__(self, csv_dataset, transform=None):
        super().__init__()
        self.transform = transform

        self.label_map = {
            'AP': 0,
            'PA': 1,
            'LATERAL': 2,
            'LL': 3
        }

        encoded = csv_dataset['ViewPosition'].map(self.label_map)

        # map() turns unknown values into NaN; fail loudly here instead of
        # producing garbage integer labels later in __getitem__.
        unmapped = encoded.isna()
        if unmapped.any():
            unknown = sorted(csv_dataset.loc[unmapped, 'ViewPosition'].astype(str).unique())
            raise ValueError(f"Unknown ViewPosition value(s): {unknown}")

        # assign() builds a new frame, so the caller's DataFrame (typically a
        # slice of a larger frame) is never mutated and can be wrapped again.
        self.data = csv_dataset.assign(ViewPosition=encoded.astype('int64'))

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            A ``(image, label)`` tuple. ``image`` is the single-channel image
            after ``transform`` (if any); ``label`` is a scalar ``torch.long``
            tensor. Both stay on the CPU so the dataset works with DataLoader
            worker processes; the consumer moves batches to the target device.
        """
        row = self.data.iloc[idx]

        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block exits.
        with Image.open(row['Image_Path']) as source_image:
            image = source_image.convert('L')

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(int(row['ViewPosition']), dtype=torch.long)

        return image, label

In [7]:
from torchvision import transforms

# Preprocessing applied to every image, in order.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),                # fixed 256x256 spatial size
        transforms.Grayscale(num_output_channels=1),  # single channel
        transforms.ToTensor(),                        # PIL image -> tensor
        transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5
    ]
)

# Wrap each DataFrame split in a MedicalDataset (train, then validation, then test),
# rebinding the same three names.
train_dataset, validation_dataset, test_dataset = (
    MedicalDataset(split_frame, transform)
    for split_frame in (train_dataset, validation_dataset, test_dataset)
)

In [8]:
from torch.utils.data import DataLoader

# NOTE: the *_dataset names are rebound from Dataset objects to DataLoader
# objects because the cells below iterate over these names.

# Training batches are reshuffled every epoch.
train_dataset = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Evaluation loaders keep a fixed order: shuffling adds nothing to evaluation
# metrics and only makes the composition of the final short batch (and
# therefore any batch-averaged loss) vary from run to run.
validation_dataset = DataLoader(validation_dataset, batch_size=32, shuffle=False)
test_dataset = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [9]:
from torch import nn

class MedicalCNN(nn.Module):
    """Three-stage convolutional classifier for single-channel square images.

    Each stage is a 3x3 convolution (padding 1, so spatial size is preserved),
    a ReLU and a 2x2 max-pool that halves the spatial size. After three stages
    an ``input_size`` x ``input_size`` image becomes a
    128 x (input_size // 8) x (input_size // 8) feature map, which feeds a
    two-layer fully connected head with dropout in between.

    Args:
        num_classes: Number of output logits.
        input_size: Height and width of the (square) input images. The default
            of 256 reproduces the original 128 * 32 * 32 flattened feature size.

    Raises:
        ValueError: If ``input_size`` is smaller than 8, which would leave no
            spatial positions after the three pooling stages.
    """

    def __init__(self, num_classes=4, input_size=256):
        super().__init__()

        if input_size < 8:
            raise ValueError(f"input_size must be at least 8, got {input_size}")

        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv_3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)

        # Three 2x2 pools with floor division are equivalent to one floor division by 8.
        feature_side = input_size // 8
        self.fully_connected_1 = nn.Linear(128 * feature_side * feature_side, 512)
        self.fully_connected_2 = nn.Linear(512, num_classes)

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=.5)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a (batch, 1, H, W) input."""
        x = self.pool(self.relu(self.conv_1(x)))
        x = self.pool(self.relu(self.conv_2(x)))
        x = self.pool(self.relu(self.conv_3(x)))

        # Flatten everything except the batch dimension. Unlike view(-1, N),
        # this never folds surplus features into the batch axis: an input of
        # the wrong spatial size fails at the linear layer with a shape error
        # instead of silently yielding a larger "batch".
        x = x.flatten(start_dim=1)
        x = self.relu(self.fully_connected_1(x))
        x = self.dropout(x)
        x = self.fully_connected_2(x)

        return x

In [10]:
from torch import optim

# Build the four-class network, then move its parameters onto the GPU.
model = MedicalCNN(num_classes=4)
model = model.to('cuda')

# Cross-entropy loss on integer class targets; Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [11]:
from tqdm import tqdm

num_epochs = 1
# Run a full validation pass after every this many training batches.
validation_interval = 313


def _evaluate(network, loader, loss_fn, device='cuda'):
    """Return ``(mean per-sample loss, accuracy in percent)`` of ``network`` over ``loader``.

    The network is switched to eval mode and left in it; callers that continue
    training must call ``network.train()`` afterwards. ``loss_fn`` is expected
    to return the mean loss of a batch, so each batch loss is re-weighted by
    its size to keep a short final batch from being over-represented.
    Returns ``(nan, nan)`` for an empty loader.
    """
    network.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
            batch_outputs = network(batch_images)
            batch_size = batch_labels.size(0)

            loss_sum += loss_fn(batch_outputs, batch_labels).item() * batch_size
            _, batch_predicted = torch.max(batch_outputs, 1)
            hits += (batch_predicted == batch_labels).sum().item()
            seen += batch_size

    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, 100 * hits / seen


for epoch in range(num_epochs):
    # Running totals over the samples seen so far in this epoch.
    running_loss = 0.0
    correct = 0
    total = 0
    batch_count = 0

    model.train()
    for images, labels in tqdm(train_dataset, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to('cuda'), labels.to('cuda')

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion yields the batch mean; weight it by the batch size so the
        # running average is per sample rather than per batch.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size

        _, predicted = torch.max(outputs, 1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        batch_count += 1
        if batch_count % validation_interval == 0:
            valid_loss, valid_accuracy = _evaluate(model, validation_dataset, criterion)
            # _evaluate leaves the model in eval mode; re-enable dropout for training.
            model.train()
            print(f"Batch {batch_count}: Train Loss: {running_loss / total} | "
                  f"Train Accuracy: {100 * correct / total:.2f}% | "
                  f"Validation Loss: {valid_loss:.4f} | Validation Accuracy: {valid_accuracy:.2f}%")

    epoch_loss = running_loss / total
    epoch_accuracy = 100 * correct / total
    print(f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%")

Epoch 1/1:   3%|▎         | 313/11009 [33:31<547:54:09, 184.41s/it]

Batch 313: Train Loss: 0.38581282192002087 | Train Accuracy: 85.16% | Validation Loss: 0.1714 | Validation Accuracy: 94.64%


Epoch 1/1:   6%|▌         | 626/11009 [58:43<179:46:15, 62.33s/it] 

Batch 626: Train Loss: 0.2764195495109541 | Train Accuracy: 89.99% | Validation Loss: 0.1244 | Validation Accuracy: 96.46%


Epoch 1/1:   9%|▊         | 939/11009 [1:24:53<259:38:47, 92.82s/it]

Batch 939: Train Loss: 0.23134616775293385 | Train Accuracy: 91.91% | Validation Loss: 0.1059 | Validation Accuracy: 96.97%


Epoch 1/1:  11%|█▏        | 1252/11009 [1:48:20<162:12:24, 59.85s/it]

Batch 1252: Train Loss: 0.20653875982656647 | Train Accuracy: 92.99% | Validation Loss: 0.1146 | Validation Accuracy: 96.39%


Epoch 1/1:  14%|█▍        | 1565/11009 [2:11:15<156:41:33, 59.73s/it]

Batch 1565: Train Loss: 0.1893404046726672 | Train Accuracy: 93.68% | Validation Loss: 0.0889 | Validation Accuracy: 97.25%


Epoch 1/1:  17%|█▋        | 1878/11009 [2:34:02<150:55:40, 59.51s/it]

Batch 1878: Train Loss: 0.1749231472150973 | Train Accuracy: 94.20% | Validation Loss: 0.0930 | Validation Accuracy: 97.43%


Epoch 1/1:  20%|█▉        | 2191/11009 [2:56:34<145:34:47, 59.43s/it]

Batch 2191: Train Loss: 0.16385414676667653 | Train Accuracy: 94.63% | Validation Loss: 0.0950 | Validation Accuracy: 96.99%


Epoch 1/1:  23%|██▎       | 2504/11009 [3:18:55<140:50:47, 59.62s/it]

Batch 2504: Train Loss: 0.1554080235191432 | Train Accuracy: 94.97% | Validation Loss: 0.0894 | Validation Accuracy: 97.72%


Epoch 1/1:  26%|██▌       | 2817/11009 [3:41:14<135:13:36, 59.43s/it]

Batch 2817: Train Loss: 0.14767814052737865 | Train Accuracy: 95.28% | Validation Loss: 0.0764 | Validation Accuracy: 98.05%


Epoch 1/1:  28%|██▊       | 3130/11009 [4:03:33<130:21:23, 59.56s/it]

Batch 3130: Train Loss: 0.14351599771758547 | Train Accuracy: 95.49% | Validation Loss: 0.0800 | Validation Accuracy: 97.79%


Epoch 1/1:  31%|███▏      | 3443/11009 [4:25:45<125:22:36, 59.66s/it]

Batch 3443: Train Loss: 0.13820993937712245 | Train Accuracy: 95.72% | Validation Loss: 0.0771 | Validation Accuracy: 98.10%


Epoch 1/1:  34%|███▍      | 3756/11009 [4:48:21<119:49:39, 59.48s/it]

Batch 3756: Train Loss: 0.1333439629452729 | Train Accuracy: 95.90% | Validation Loss: 0.1185 | Validation Accuracy: 96.46%


Epoch 1/1:  37%|███▋      | 4069/11009 [5:10:51<118:40:15, 61.56s/it]

Batch 4069: Train Loss: 0.12961351523643558 | Train Accuracy: 96.05% | Validation Loss: 0.0694 | Validation Accuracy: 98.34%


Epoch 1/1:  40%|███▉      | 4382/11009 [5:33:26<109:40:53, 59.58s/it]

Batch 4382: Train Loss: 0.125534631989318 | Train Accuracy: 96.19% | Validation Loss: 0.0766 | Validation Accuracy: 98.21%


Epoch 1/1:  43%|████▎     | 4695/11009 [5:55:55<103:26:33, 58.98s/it]

Batch 4695: Train Loss: 0.12214623533749182 | Train Accuracy: 96.30% | Validation Loss: 0.0756 | Validation Accuracy: 98.12%


Epoch 1/1:  45%|████▌     | 5008/11009 [6:18:25<102:56:50, 61.76s/it]

Batch 5008: Train Loss: 0.1197679149125109 | Train Accuracy: 96.40% | Validation Loss: 0.0741 | Validation Accuracy: 98.16%


Epoch 1/1:  48%|████▊     | 5321/11009 [6:41:29<96:49:47, 61.28s/it] 

Batch 5321: Train Loss: 0.11724199302479386 | Train Accuracy: 96.49% | Validation Loss: 0.0678 | Validation Accuracy: 98.41%


Epoch 1/1:  51%|█████     | 5634/11009 [7:04:45<90:40:09, 60.73s/it]

Batch 5634: Train Loss: 0.11526138881287701 | Train Accuracy: 96.57% | Validation Loss: 0.0725 | Validation Accuracy: 98.14%


Epoch 1/1:  54%|█████▍    | 5947/11009 [7:28:07<86:17:13, 61.37s/it]

Batch 5947: Train Loss: 0.11260284856598982 | Train Accuracy: 96.67% | Validation Loss: 0.1018 | Validation Accuracy: 97.45%


Epoch 1/1:  57%|█████▋    | 6260/11009 [7:53:30<81:34:47, 61.84s/it]

Batch 6260: Train Loss: 0.11097723173669177 | Train Accuracy: 96.73% | Validation Loss: 0.0749 | Validation Accuracy: 98.25%


Epoch 1/1:  60%|█████▉    | 6573/11009 [8:18:58<76:06:45, 61.77s/it]

Batch 6573: Train Loss: 0.10971809773783368 | Train Accuracy: 96.79% | Validation Loss: 0.0627 | Validation Accuracy: 98.58%


Epoch 1/1:  63%|██████▎   | 6886/11009 [8:43:39<70:55:54, 61.93s/it]

Batch 6886: Train Loss: 0.10782500944685643 | Train Accuracy: 96.86% | Validation Loss: 0.0653 | Validation Accuracy: 98.36%


Epoch 1/1:  65%|██████▌   | 7199/11009 [9:09:06<65:19:45, 61.73s/it]

Batch 7199: Train Loss: 0.10615097154728118 | Train Accuracy: 96.93% | Validation Loss: 0.0797 | Validation Accuracy: 97.70%


Epoch 1/1:  68%|██████▊   | 7512/11009 [9:35:10<60:45:06, 62.54s/it]

Batch 7512: Train Loss: 0.10453411091851575 | Train Accuracy: 96.99% | Validation Loss: 0.0687 | Validation Accuracy: 98.36%


Epoch 1/1:  71%|███████   | 7825/11009 [9:58:15<51:12:17, 57.89s/it]

Batch 7825: Train Loss: 0.10290596592509316 | Train Accuracy: 97.04% | Validation Loss: 0.0634 | Validation Accuracy: 98.41%


Epoch 1/1:  74%|███████▍  | 8138/11009 [10:22:12<46:17:59, 58.06s/it]

Batch 8138: Train Loss: 0.10174991923902958 | Train Accuracy: 97.09% | Validation Loss: 0.0749 | Validation Accuracy: 98.25%


Epoch 1/1:  77%|███████▋  | 8451/11009 [10:46:49<41:04:08, 57.80s/it]

Batch 8451: Train Loss: 0.10042150706912664 | Train Accuracy: 97.14% | Validation Loss: 0.0607 | Validation Accuracy: 98.65%


Epoch 1/1:  80%|███████▉  | 8764/11009 [11:11:08<36:05:54, 57.89s/it]

Batch 8764: Train Loss: 0.0992081664364331 | Train Accuracy: 97.18% | Validation Loss: 0.0636 | Validation Accuracy: 98.52%


Epoch 1/1:  82%|████████▏ | 9077/11009 [11:34:22<31:02:35, 57.84s/it]

Batch 9077: Train Loss: 0.09786303739066397 | Train Accuracy: 97.23% | Validation Loss: 0.0655 | Validation Accuracy: 98.49%


Epoch 1/1:  85%|████████▌ | 9390/11009 [11:58:14<26:07:02, 58.07s/it]

Batch 9390: Train Loss: 0.09714852670427587 | Train Accuracy: 97.27% | Validation Loss: 0.0679 | Validation Accuracy: 98.16%


Epoch 1/1:  88%|████████▊ | 9703/11009 [12:21:09<21:02:21, 58.00s/it]

Batch 9703: Train Loss: 0.09642075731507457 | Train Accuracy: 97.29% | Validation Loss: 0.0622 | Validation Accuracy: 98.52%


Epoch 1/1:  91%|█████████ | 10016/11009 [12:46:17<15:57:49, 57.88s/it]

Batch 10016: Train Loss: 0.09523248128138098 | Train Accuracy: 97.34% | Validation Loss: 0.0604 | Validation Accuracy: 98.61%


Epoch 1/1:  94%|█████████▍| 10329/11009 [13:10:42<10:56:23, 57.92s/it]

Batch 10329: Train Loss: 0.09445463369090683 | Train Accuracy: 97.37% | Validation Loss: 0.0625 | Validation Accuracy: 98.58%


Epoch 1/1:  97%|█████████▋| 10642/11009 [13:35:56<5:55:22, 58.10s/it] 

Batch 10642: Train Loss: 0.09363973015467937 | Train Accuracy: 97.40% | Validation Loss: 0.0576 | Validation Accuracy: 98.56%


Epoch 1/1: 100%|█████████▉| 10955/11009 [14:00:29<52:11, 57.99s/it]  

Batch 10955: Train Loss: 0.09240689024316885 | Train Accuracy: 97.44% | Validation Loss: 0.0631 | Validation Accuracy: 98.67%


Epoch 1/1: 100%|██████████| 11009/11009 [14:03:53<00:00,  4.60s/it]

Train Loss: 0.0923, Train Accuracy: 97.45%





In [12]:
# Persist only the model's state_dict (not the module object) to 'model_1.pth'.
torch.save(model.state_dict(), 'model_1.pth')

In [13]:
# Final evaluation on the held-out test loader: eval mode, no gradient tracking.
model.eval()
test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_dataset:
        images, labels = images.to('cuda'), labels.to('cuda')
        outputs = model(images)
        loss = criterion(outputs, labels)

        # criterion yields the batch mean; weight it by the batch size so the
        # short final batch does not count as much as a full one.
        batch_size = labels.size(0)
        test_loss += loss.item() * batch_size

        _, predicted = torch.max(outputs, 1)
        total += batch_size
        correct += (predicted == labels).sum().item()

# Per-sample averages over the whole test set.
test_loss = test_loss / total
test_accuracy = 100 * correct / total
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")