In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

In [2]:
!nvidia-smi # Check visible NVIDIA GPUs, their memory usage and utilization before training

Thu Jan 30 01:00:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A40                     On  |   00000000:01:00.0 Off |                    0 |
|  0%   37C    P0             75W /  300W |   40839MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                     On  |   00

In [3]:
############################################################
# 1) DATA PREPARATION (Both 1-channel and 3-channel)
############################################################

# Dataset root. Set the CHEST_XRAY_DIR environment variable to point at a
# different copy of the data; the original cluster path is kept as the default.
# ImageFolder expects <root>/<class_name>/<image> under each split directory.
data_dir = os.environ.get(
    "CHEST_XRAY_DIR", "/nobackup/kumar13/data/ChestXRay2017/chest_xray"
)
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')

# Every model below is fed 224x224 images (SimpleCNN's classifier head is
# sized for exactly this resolution).
IMAGE_SIZE = (224, 224)

# Transform for 1-channel inputs (used by CNN and CNN-LSTM).
# Normalize with mean=std=0.5 maps ToTensor's [0, 1] range to [-1, 1].
transform_1ch = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Transform for 3-channel inputs (used by AlexNet, ResNet and the Vision
# Transformer): the grayscale image is replicated into 3 identical channels
# so the stock 3-channel architectures can consume it.
transform_3ch = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])


In [4]:
# Shared loader settings for both channel variants.
BATCH_SIZE = 32
# Page-locked host memory makes the per-batch .to(device) copies cheaper on GPU.
PIN_MEMORY = torch.cuda.is_available()


def _make_loaders(transform):
    """Build the train/test ImageFolder datasets and loaders for one transform.

    Args:
        transform: torchvision transform applied to every image.

    Returns:
        Tuple ``(train_data, test_data, train_loader, test_loader)``. The
        training loader shuffles; the test loader keeps ImageFolder order.

    Raises:
        ValueError: if the train and test splits map class folders to
            different label indices.
    """
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)
    # Label ids are derived from each split's own folder names; if the splits
    # disagree, test accuracy would silently compare mismatched labels.
    if train_data.class_to_idx != test_data.class_to_idx:
        raise ValueError(
            f"train/test class mismatch: {train_data.class_to_idx} "
            f"vs {test_data.class_to_idx}"
        )
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                              pin_memory=PIN_MEMORY)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False,
                             pin_memory=PIN_MEMORY)
    return train_data, test_data, train_loader, test_loader


# 1-channel datasets/loaders
train_data_1ch, test_data_1ch, train_loader_1ch, test_loader_1ch = _make_loaders(transform_1ch)

# 3-channel datasets/loaders
train_data_3ch, test_data_3ch, train_loader_3ch, test_loader_3ch = _make_loaders(transform_3ch)

In [5]:
# Device / reproducibility configuration.
# Batches are moved to cuda:0 when a GPU is available. The nn.DataParallel
# wrappers below keep their parameters on their first device, which is
# cuda:0 here. Without a GPU everything runs on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# repeat across Restart & Run All. Some cuDNN kernels can still be
# non-deterministic unless torch.backends.cudnn.deterministic is also set.
SEED = 42
torch.manual_seed(SEED)


In [6]:
############################################################
# 2) SIMPLE CNN 
############################################################

class SimpleCNN(nn.Module):
    """Small baseline CNN: two conv-ReLU-maxpool stages followed by an MLP head.

    Args:
        in_channels: Channels of the input images (1 for ``transform_1ch``).
        num_classes: Number of output logits.
        input_size: Height/width of the square input images. The flattened
            feature size of the classifier is derived from it, so it must
            match the images actually fed in (224 in this notebook).

    The defaults reproduce the original architecture exactly
    (``Linear(64 * 56 * 56, 128)`` for 224x224 single-channel inputs).
    """

    def __init__(self, in_channels=1, num_classes=2, input_size=224):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        # 3x3 convs with padding=1 keep the spatial size; each MaxPool2d(2, 2)
        # halves it (flooring), so two pools leave input_size // 4 per side.
        feature_side = (input_size // 2) // 2
        self.fc = nn.Sequential(
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Map a (batch, in_channels, input_size, input_size) tensor to logits."""
        x = self.conv(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

cnn_model = SimpleCNN()
# Leaving device_ids unset lets DataParallel use every visible CUDA device.
# Hard-coding [0, 1, 2, 3] fails on hosts with fewer than four GPUs; without
# a GPU the wrapper simply runs the module directly.
cnn_model = nn.DataParallel(cnn_model).to(device)

criterion_cnn = nn.CrossEntropyLoss()
optimizer_cnn = optim.Adam(cnn_model.parameters(), lr=0.001)

In [7]:
def train_model(model, loader, criterion, optimizer, num_epochs=5, device=None):
    """Train ``model`` in place and print the mean batch loss of every epoch.

    Args:
        model: Module to optimise; it is switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of full passes over ``loader``.
        device: Where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model), so the
            function no longer depends on a notebook-global ``device``.

    Returns:
        List of per-epoch mean batch losses (``nan`` for an epoch with no
        batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Counting batches avoids a ZeroDivisionError on an empty loader.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}")
    return history

In [8]:
def test_accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

In [9]:
############################################################
# 3) CNN-LSTM HYBRID MODEL
############################################################

class CNN_LSTM(nn.Module):
    """CNN frame encoder followed by an LSTM over the frame sequence.

    Each frame goes through two conv-ReLU-maxpool stages and global average
    pooling. The resulting 128-dim vectors form the LSTM's input sequence,
    and the LSTM output at the last timestep is classified.

    Args:
        in_channels: Channels per frame (1 for ``transform_1ch``).
        num_classes: Number of output logits.
        hidden_size: LSTM hidden size (also the classifier's input width).

    The defaults reproduce the original architecture exactly.
    """

    def __init__(self, in_channels=1, num_classes=2, hidden_size=128):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.lstm = nn.LSTM(128, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a (batch, timesteps, C, H, W) sequence.

        A plain (batch, C, H, W) image batch is also accepted and treated as
        a one-step sequence.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)
        batch_size, timesteps, C, H, W = x.size()
        # Fold time into the batch dimension so the CNN sees ordinary images.
        frames = x.reshape(batch_size * timesteps, C, H, W)
        features = self.global_pool(self.cnn(frames))
        features = features.reshape(batch_size, timesteps, -1)
        lstm_out, _ = self.lstm(features)
        return self.fc(lstm_out[:, -1, :])

cnn_lstm_model = CNN_LSTM()
# Leaving device_ids unset lets DataParallel use every visible CUDA device.
# Hard-coding [0, 1, 2, 3] fails on hosts with fewer than four GPUs.
cnn_lstm_model = nn.DataParallel(cnn_lstm_model).to(device)

criterion_cnn_lstm = nn.CrossEntropyLoss()
optimizer_cnn_lstm = optim.Adam(cnn_lstm_model.parameters(), lr=0.001)


In [10]:
def train_model_cnn_lstm(model, loader, criterion, optimizer, num_epochs=5, device=None):
    """Train a sequence model on single images by treating each one as a 1-step sequence.

    Identical to ``train_model`` except that every image batch gets a
    singleton time dimension at dim 1 before the forward pass.

    Args:
        model: Module to optimise; it is switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of full passes over ``loader``.
        device: Where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        List of per-epoch mean batch losses (``nan`` for an epoch with no
        batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in loader:
            # Add a time dimension of length 1: (B, C, H, W) -> (B, 1, C, H, W).
            images = images.unsqueeze(1).to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}")
    return history

In [11]:
def test_accuracy_cnn_lstm(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.unsqueeze(1).to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

In [12]:
############################################################
# 4) ALEXNET
############################################################

from torchvision.models import alexnet

alexnet_model = alexnet(weights=None)
# Replace the final classifier layer with a 2-class head, reusing the existing
# layer's input width rather than hard-coding it (as the ResNet cell does).
alexnet_model.classifier[6] = nn.Linear(alexnet_model.classifier[6].in_features, 2)
# Leaving device_ids unset lets DataParallel use every visible CUDA device.
# Hard-coding [0, 1, 2, 3] fails on hosts with fewer than four GPUs.
alexnet_model = nn.DataParallel(alexnet_model).to(device)

criterion_alex = nn.CrossEntropyLoss()
optimizer_alex = optim.Adam(alexnet_model.parameters(), lr=0.001)

In [13]:
############################################################
# 5) RESNET (for example ResNet18)
############################################################

from torchvision.models import resnet18

resnet_model = resnet18(weights=None)
# Swap the classification head for a 2-class output of the same input width.
num_ftrs = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(num_ftrs, 2)
# Leaving device_ids unset lets DataParallel use every visible CUDA device.
# Hard-coding [0, 1, 2, 3] fails on hosts with fewer than four GPUs.
resnet_model = nn.DataParallel(resnet_model).to(device)

criterion_resnet = nn.CrossEntropyLoss()
optimizer_resnet = optim.Adam(resnet_model.parameters(), lr=0.001)

In [14]:
############################################################
# 6) VISION TRANSFORMER (Using timm)
############################################################

# timm is optional. Only the import is guarded, so an ImportError raised while
# building the model is no longer mistaken for "timm is not installed".
try:
    import timm
except ImportError:
    timm = None

has_timm = timm is not None
if has_timm:
    vit_model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=2)
    # Leaving device_ids unset lets DataParallel use every visible CUDA device.
    vit_model = nn.DataParallel(vit_model).to(device)
    criterion_vit = nn.CrossEntropyLoss()
    # NOTE(review): the logged ViT loss stays around 0.57-0.60 from epoch 2 on,
    # so lr=1e-3 may be too high for a ViT trained from scratch. Consider a
    # smaller rate (TODO tune).
    optimizer_vit = optim.Adam(vit_model.parameters(), lr=0.001)
else:
    print("Warning: timm is not installed. Vision Transformer code will be skipped.")
    vit_model = None

In [15]:
############################################################
# 7) TRAIN AND TEST ALL MODELS
############################################################

print("\n==================== TRAINING CNN ====================")
# Fit for five epochs, then score the trained weights on both splits.
train_model(
    model=cnn_model,
    loader=train_loader_1ch,
    criterion=criterion_cnn,
    optimizer=optimizer_cnn,
    num_epochs=5,
)
cnn_train_acc, cnn_test_acc = (
    test_accuracy(model=cnn_model, loader=train_loader_1ch),
    test_accuracy(model=cnn_model, loader=test_loader_1ch),
)


Epoch [1/5] - Loss: 0.2524
Epoch [2/5] - Loss: 0.0718
Epoch [3/5] - Loss: 0.0548
Epoch [4/5] - Loss: 0.0401
Epoch [5/5] - Loss: 0.0264


In [16]:
print("\n==================== TRAINING CNN-LSTM ====================")
# Same 1-channel data as the CNN, wrapped as 1-step sequences by the helpers.
train_model_cnn_lstm(
    model=cnn_lstm_model,
    loader=train_loader_1ch,
    criterion=criterion_cnn_lstm,
    optimizer=optimizer_cnn_lstm,
    num_epochs=5,
)
cnn_lstm_train_acc, cnn_lstm_test_acc = (
    test_accuracy_cnn_lstm(model=cnn_lstm_model, loader=train_loader_1ch),
    test_accuracy_cnn_lstm(model=cnn_lstm_model, loader=test_loader_1ch),
)


Epoch [1/5] - Loss: 0.5781
Epoch [2/5] - Loss: 0.5325
Epoch [3/5] - Loss: 0.4362
Epoch [4/5] - Loss: 0.3445
Epoch [5/5] - Loss: 0.3035


In [17]:
print("\n==================== TRAINING ALEXNET ====================")
# AlexNet consumes the 3-channel (replicated grayscale) loaders.
train_model(
    model=alexnet_model,
    loader=train_loader_3ch,
    criterion=criterion_alex,
    optimizer=optimizer_alex,
    num_epochs=5,
)
alexnet_train_acc, alexnet_test_acc = (
    test_accuracy(model=alexnet_model, loader=train_loader_3ch),
    test_accuracy(model=alexnet_model, loader=test_loader_3ch),
)


Epoch [1/5] - Loss: 0.4448
Epoch [2/5] - Loss: 0.2106
Epoch [3/5] - Loss: 0.1346
Epoch [4/5] - Loss: 0.1242
Epoch [5/5] - Loss: 0.1191


In [18]:
print("\n==================== TRAINING RESNET ====================")
# ResNet18 also uses the 3-channel loaders.
train_model(
    model=resnet_model,
    loader=train_loader_3ch,
    criterion=criterion_resnet,
    optimizer=optimizer_resnet,
    num_epochs=5,
)
resnet_train_acc, resnet_test_acc = (
    test_accuracy(model=resnet_model, loader=train_loader_3ch),
    test_accuracy(model=resnet_model, loader=test_loader_3ch),
)


Epoch [1/5] - Loss: 0.2595
Epoch [2/5] - Loss: 0.1435
Epoch [3/5] - Loss: 0.1169
Epoch [4/5] - Loss: 0.0756
Epoch [5/5] - Loss: 0.0821


In [19]:
# Without timm there is no ViT; record "no result" so later cells still run.
if not has_timm:
    vit_train_acc = vit_test_acc = None
else:
    print("\n==================== TRAINING VISION TRANSFORMER ====================")
    train_model(
        model=vit_model,
        loader=train_loader_3ch,
        criterion=criterion_vit,
        optimizer=optimizer_vit,
        num_epochs=5,
    )
    vit_train_acc, vit_test_acc = (
        test_accuracy(model=vit_model, loader=train_loader_3ch),
        test_accuracy(model=vit_model, loader=test_loader_3ch),
    )


Epoch [1/5] - Loss: 0.6956
Epoch [2/5] - Loss: 0.5887
Epoch [3/5] - Loss: 0.5815
Epoch [4/5] - Loss: 0.6010
Epoch [5/5] - Loss: 0.5717


In [20]:
############################################################
# 8) SHOW ACCURACIES
############################################################

print("\n==================== ACCURACIES ====================")
header = "| Model               | Train Acc  | Test Acc   |"
line   = "|:--------------------|-----------:|-----------:|"
print(header)
print(line)

accuracy_rows = [
    ("CNN", cnn_train_acc, cnn_test_acc),
    ("CNN-LSTM", cnn_lstm_train_acc, cnn_lstm_test_acc),
    ("AlexNet", alexnet_train_acc, alexnet_test_acc),
    ("ResNet18", resnet_train_acc, resnet_test_acc),
]
if has_timm:
    accuracy_rows.append(("ViT (Base/16)", vit_train_acc, vit_test_acc))

# One shared row template sized to the header: a 19-char left-aligned name
# column and 10-char right-aligned numeric columns. Hand-padding each row
# previously left the columns misaligned.
for model_name, train_acc, test_acc in accuracy_rows:
    print(f"| {model_name:<19} | {train_acc:>10.4f} | {test_acc:>10.4f} |")

print("\nDone.")


| Model               | Train Acc  | Test Acc   |
|:--------------------|-----------:|-----------:|
| CNN                 | 0.9954     | 0.7756     |
| CNN-LSTM            | 0.8362 | 0.6859 |
| AlexNet             | 0.9736  | 0.7885  |
| ResNet18            | 0.9841   | 0.7997   |
| ViT (Base/16)       | 0.7410      | 0.6346      |

Done.
