In [None]:
%matplotlib inline

import os
import shutil
import random
import torch
import torchvision
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt

# Seed every RNG the notebook draws from: Python `random` drives the test split
# and the per-item class choice in ChestXRayDataset.__getitem__, torch drives
# DataLoader shuffling and the new classifier's initialization, and numpy is
# seeded for completeness.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print('Using PyTorch version', torch.__version__)

In [None]:
class_names = ['normal', 'viral', 'covid']
root_dir = 'COVID-19 Radiography Database'
source_dirs = ['NORMAL', 'Viral Pneumonia', 'COVID-19']
NUM_TEST_PER_CLASS = 30


def split_test_set(root_dir, source_dirs, class_names, num_test=NUM_TEST_PER_CLASS):
    """Rename the raw class folders and move a random per-class test sample.

    Each ``source_dirs[i]`` under ``root_dir`` is renamed to ``class_names[i]``.
    Then up to ``num_test`` PNG files per class are moved into
    ``root_dir/test/<class_name>``. The work is skipped once ``source_dirs[1]``
    no longer exists, so re-running the cell does not move more images.

    Args:
        root_dir: dataset root that contains the raw class folders.
        source_dirs: original folder names, aligned index-by-index with ``class_names``.
        class_names: new folder names, also used for the folders under ``test``.
        num_test: images to move per class; capped at the number available.

    Returns:
        True if the split was performed, False if the dataset was already prepared.
    """
    # With the default names, source_dirs[1] ('Viral Pneumonia') differs from its
    # new name by more than letter case, so this check stays meaningful on
    # case-insensitive filesystems (where 'NORMAL' would still resolve after
    # being renamed to 'normal').
    if not os.path.isdir(os.path.join(root_dir, source_dirs[1])):
        return False

    for source, target in zip(source_dirs, class_names):
        os.rename(os.path.join(root_dir, source), os.path.join(root_dir, target))

    for name in class_names:
        class_dir = os.path.join(root_dir, name)
        test_dir = os.path.join(root_dir, 'test', name)
        # exist_ok: tolerate a 'test' tree left behind by an interrupted run.
        os.makedirs(test_dir, exist_ok=True)
        # os.listdir order is filesystem-dependent; sort so the seeded sample
        # selects the same files on every machine.
        images = sorted(x for x in os.listdir(class_dir) if x.lower().endswith('png'))
        for image in random.sample(images, min(num_test, len(images))):
            shutil.move(os.path.join(class_dir, image), os.path.join(test_dir, image))
    return True


split_performed = split_test_set(root_dir, source_dirs, class_names)

In [None]:
class ChestXRayDataset(torch.utils.data.Dataset):
    """Chest X-ray PNGs stored in one folder per class, sampled class-balanced.

    ``__getitem__`` does not map ``index`` to a fixed image. It picks a class
    uniformly at random (Python ``random``), then takes image
    ``index % n_images_in_class`` from that class. Each class is therefore drawn
    equally often regardless of its size. An epoch does not visit every image
    exactly once, and metrics computed over this dataset are random estimates.

    Args:
        image_dirs: mapping from each of 'normal', 'viral', 'covid' to the folder
            holding that class's PNG files.
        transform: callable applied to the RGB PIL image before it is returned.
    """

    def __init__(self, image_dirs, transform):
        self.class_names = ['normal', 'viral', 'covid']
        self.image_dirs = image_dirs
        self.transform = transform
        self.images = {}
        for class_name in self.class_names:
            # os.listdir order is filesystem-dependent; sorting makes the
            # index -> file mapping reproducible under a fixed seed.
            images = sorted(
                x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('png')
            )
            print(f'Found {len(images)} {class_name} examples')
            self.images[class_name] = images

    def __len__(self):
        """Total number of images across all classes."""
        return sum(len(self.images[class_name]) for class_name in self.class_names)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` from a randomly chosen class.

        ``label`` is the position of the chosen class in ``self.class_names``.
        Raises ZeroDivisionError if the chosen class has no images.
        """
        class_name = random.choice(self.class_names)
        index = index % len(self.images[class_name])
        image_name = self.images[class_name][index]
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        image = Image.open(image_path).convert('RGB')
        return self.transform(image), self.class_names.index(class_name)

In [None]:
def _make_transform(augment):
    """Build the preprocessing pipeline.

    The pipeline resizes to 224x224, optionally applies a random horizontal
    flip, converts to a tensor, and normalizes with the standard ImageNet
    channel statistics.
    """
    steps = [torchvision.transforms.Resize(size=(224, 224))]
    if augment:
        steps.append(torchvision.transforms.RandomHorizontalFlip())
    steps.append(torchvision.transforms.ToTensor())
    steps.append(torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225]))
    return torchvision.transforms.Compose(steps)


# Training images get a random horizontal flip as augmentation; test images do not.
train_transform = _make_transform(augment=True)
test_transform = _make_transform(augment=False)

In [None]:
# One folder per class directly under the dataset root (as renamed by the preparation cell).
train_dirs = {name: f'{root_dir}/{name}' for name in ('normal', 'viral', 'covid')}

train_dataset = ChestXRayDataset(train_dirs, train_transform)

In [None]:
# Held-out images live in one folder per class under <root>/test.
test_dirs = {name: f'{root_dir}/test/{name}' for name in ('normal', 'viral', 'covid')}

test_dataset = ChestXRayDataset(test_dirs, test_transform)

In [None]:
# Six images per batch; show_images draws one panel per image in a batch.
batch_size = 6

# Both loaders shuffle. For the test set this only reorders indices, because
# ChestXRayDataset picks the class of each item at random anyway.
dl_train, dl_test = (
    torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True)
    for ds in (train_dataset, test_dataset)
)

print('Number of training batches', len(dl_train))
print('Number of test batches', len(dl_test))

In [None]:
class_names = train_dataset.class_names


def show_images(images, labels, preds):
    """Plot a row of de-normalized images with their true and predicted classes.

    The x-label shows the true class. The y-label shows the predicted class,
    coloured green when it matches the label and red otherwise.

    Args:
        images: batch of normalized image tensors; each is transposed from
            (C, H, W) to (H, W, C) for display.
        labels: per-image integer class indices (tensor).
        preds: per-image predicted class indices (tensor).
    """
    # Undo the Normalize step applied by the transforms.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # At least 6 columns (the batch size) keeps the original layout for normal
    # batches; larger batches get more columns instead of raising.
    n_cols = max(6, len(images))
    plt.figure(figsize=(8, 4))
    for i, image in enumerate(images):
        plt.subplot(1, n_cols, i + 1, xticks=[], yticks=[])
        # detach().cpu() so tensors coming straight from a (possibly GPU) model work.
        image = image.detach().cpu().numpy().transpose((1, 2, 0))
        image = np.clip(image * std + mean, 0., 1.)
        plt.imshow(image)
        label, pred = int(labels[i]), int(preds[i])
        plt.xlabel(class_names[label])
        plt.ylabel(class_names[pred], color='green' if pred == label else 'red')
    plt.tight_layout()
    plt.show()

In [None]:
# Sanity-check one training batch. Labels are passed as preds, so every y-label is green.
images, labels = next(iter(dl_train))
show_images(images, labels, labels)

In [None]:
# Sanity-check one test batch. Labels are passed as preds, so every y-label is green.
images, labels = next(iter(dl_test))
show_images(images, labels, labels)

In [None]:
# Load the ImageNet-pretrained ResNet-18. torchvision >= 0.13 deprecates
# `pretrained=True` in favour of the `weights` enum; older versions lack the
# enum, so fall back to the legacy flag there.
try:
    resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
except (AttributeError, TypeError):
    resnet18 = torchvision.models.resnet18(pretrained=True)

print(resnet18)

In [None]:
# Replace the pretrained classification head with one output per dataset class.
# Sizes come from the model and dataset rather than hard-coded 512/3.
# Re-running this cell resets the head and the optimizer, discarding training.
resnet18.fc = torch.nn.Linear(
    in_features=resnet18.fc.in_features,
    out_features=len(train_dataset.class_names),
)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet18.parameters(), lr=3e-5)

In [None]:
def show_preds():
    """Plot one test batch together with the model's predictions.

    Leaves ``resnet18`` in eval mode; callers that continue training must call
    ``resnet18.train()`` afterwards.
    """
    resnet18.eval()
    images, labels = next(iter(dl_test))
    # Run on whatever device the model currently lives on. no_grad avoids
    # building an autograd graph for a visualization-only forward pass.
    device = next(resnet18.parameters()).device
    with torch.no_grad():
        outputs = resnet18(images.to(device))
    _, preds = torch.max(outputs, 1)
    show_images(images, labels, preds.cpu())

In [None]:
# Baseline: predictions from the freshly replaced classification head, before any training.
show_preds()

In [None]:
def _evaluate():
    """Return ``(mean per-batch loss, accuracy)`` of ``resnet18`` on ``dl_test``.

    Switches the model to eval mode and does not switch it back.
    """
    device = next(resnet18.parameters()).device
    resnet18.eval()
    total_loss = 0.
    n_batches = 0
    correct = 0
    n_samples = 0
    with torch.no_grad():
        for images, labels in dl_test:
            images, labels = images.to(device), labels.to(device)
            outputs = resnet18(images)
            total_loss += loss_fn(outputs, labels).item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            n_samples += labels.size(0)
            n_batches += 1
    return total_loss / max(n_batches, 1), correct / max(n_samples, 1)


def train(epochs):
    """Fine-tune ``resnet18`` on ``dl_train`` for up to ``epochs`` epochs.

    Every 20 training steps (including step 0), the function evaluates on
    ``dl_test``, prints the validation loss and accuracy, and plots one batch
    of predictions. It returns early as soon as validation accuracy reaches
    0.95. Otherwise it prints the mean training loss at the end of each epoch.
    """
    print('Starting training..')
    # Feed batches to wherever the model lives (CPU or GPU).
    device = next(resnet18.parameters()).device
    for e in range(0, epochs):
        print('='*20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('='*20)

        train_loss = 0.
        n_steps = 0

        resnet18.train()  # set model to training phase

        for train_step, (images, labels) in enumerate(dl_train):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = resnet18(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_steps += 1

            if train_step % 20 == 0:
                print('Evaluating at step', train_step)
                # Fresh accumulators on every evaluation (previously val_loss
                # carried over the prior evaluation's average).
                val_loss, accuracy = _evaluate()
                print(f'Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}')

                show_preds()

                resnet18.train()  # back to training phase after eval/plots

                if accuracy >= 0.95:
                    print('Performance condition satisfied, stopping..')
                    return

        # max(): an empty dl_train must not divide by zero.
        train_loss /= max(n_steps, 1)

        print(f'Training Loss: {train_loss:.4f}')
    print('Training complete..')

In [None]:
%%time

# One epoch over dl_train; train() may return earlier once validation accuracy reaches 0.95.
train(epochs=1)

In [None]:
# Predictions on a test batch after fine-tuning.
show_preds()

In [None]:
# Evaluation: accuracy, confusion matrix, per-class precision/recall
import torch
import matplotlib.pyplot as plt
import numpy as np

# NOTE: ChestXRayDataset.__getitem__ draws a random class for every item, so this
# is a class-balanced random estimate over the test folders rather than one pass
# over each test image; numbers vary between runs.

# Evaluate on the device the model already lives on. Moving only the inputs to
# CUDA would fail for a model that was trained on CPU.
device = next(resnet18.parameters()).device

# Class names from the test loader's dataset: ChestXRayDataset exposes
# `class_names`; torchvision-style datasets expose `classes` / `class_to_idx`.
# Stored in `eval_class_names` so the global `class_names` used by show_images
# is not overwritten.
eval_dataset = getattr(dl_test, 'dataset', None)
eval_class_names = None
for attr in ('class_names', 'classes'):
    names = getattr(eval_dataset, attr, None)
    if isinstance(names, (list, tuple)):
        eval_class_names = list(names)
        break
if eval_class_names is None and isinstance(getattr(eval_dataset, 'class_to_idx', None), dict):
    idx_to_class = {v: k for k, v in eval_dataset.class_to_idx.items()}
    eval_class_names = [idx_to_class[i] for i in range(len(idx_to_class))]

# The model's output width is authoritative; fall back to numeric names if the
# dataset names are missing or do not match it.
num_classes = resnet18.fc.out_features
if eval_class_names is None or len(eval_class_names) != num_classes:
    eval_class_names = [str(i) for i in range(num_classes)]

# Compute confusion matrix (rows = true class, columns = predicted) and accuracy
conf_mat = torch.zeros(num_classes, num_classes, dtype=torch.int64)
correct = 0
total = 0

resnet18.eval()
with torch.no_grad():
    for batch in dl_test:
        # Support datasets that return (images, labels) or dicts
        if isinstance(batch, (list, tuple)) and len(batch) >= 2:
            images, labels = batch[0], batch[1]
        elif isinstance(batch, dict):
            images, labels = batch.get('images'), batch.get('labels')
        else:
            raise RuntimeError('Unexpected batch format from dl_test')

        outputs = resnet18(images.to(device))
        preds = outputs.argmax(dim=1).cpu()
        labels = labels.cpu()

        # Encode each (true, pred) pair as one flat index and count them all at once.
        conf_mat += torch.bincount(
            labels.view(-1) * num_classes + preds.view(-1),
            minlength=num_classes * num_classes,
        ).view(num_classes, num_classes)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

accuracy = (100.0 * correct / total) if total else 0.0
print(f'Test accuracy: {accuracy:.2f}%')

# Per-class precision and recall (clamp avoids 0/0 for classes never seen/predicted)
tp = conf_mat.diag().to(torch.float32)
precision = tp / conf_mat.sum(0).clamp(min=1)
recall = tp / conf_mat.sum(1).clamp(min=1)

for i, name in enumerate(eval_class_names):
    print(f'{name}: precision={precision[i].item():.3f}, recall={recall[i].item():.3f}')

# Visualize confusion matrix
fig, ax = plt.subplots(figsize=(6, 5))
cax = ax.imshow(conf_mat.numpy(), cmap='Blues')
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_xticks(range(num_classes))
ax.set_yticks(range(num_classes))
ax.set_xticklabels(eval_class_names, rotation=45, ha='right')
ax.set_yticklabels(eval_class_names)
fig.colorbar(cax)
plt.tight_layout()
plt.show()

In [None]:
# Save and Load Model; TorchScript export
import os
import torch
import torchvision

# Device for the reloaded demo model only; the trained `resnet18` is left where it is.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Ensure models directory exists
os.makedirs('models', exist_ok=True)

# Save current model weights
ckpt_path = os.path.join('models', 'resnet18_covid_xray.pth')
try:
    torch.save(resnet18.state_dict(), ckpt_path)
    print(f'Saved weights to: {ckpt_path}')
except Exception as e:
    print('Failed to save weights:', e)

# Demonstrate reloading into a fresh model
try:
    fresh_model = torchvision.models.resnet18(weights=None)
except TypeError:
    # Fallback for older torchvision versions
    fresh_model = torchvision.models.resnet18(pretrained=False)

# Match the classifier layer dimensions to original model
try:
    out_feats = resnet18.fc.out_features
    fresh_model.fc = torch.nn.Linear(fresh_model.fc.in_features, out_feats)
except Exception as e:
    print('Warning: could not match classifier layer automatically:', e)

try:
    state = torch.load(ckpt_path, map_location=device)
    fresh_model.load_state_dict(state)
    fresh_model.to(device).eval()
    print('Reloaded weights into fresh model and set to eval().')
except Exception as e:
    print('Failed to reload weights:', e)

# Optional: export TorchScript for deployment.
# Trace on the model's current device instead of moving the global `resnet18`
# (moving it to CUDA would break later train()/show_preds() calls that feed CPU
# tensors), and restore its previous train/eval mode afterwards.
script_path = os.path.join('models', 'resnet18_covid_xray_script.pt')
was_training = resnet18.training
try:
    resnet18.eval()
    model_device = next(resnet18.parameters()).device
    example = torch.randn(1, 3, 224, 224, device=model_device)
    with torch.no_grad():
        scripted = torch.jit.trace(resnet18, example)
    scripted.save(script_path)
    print(f'Saved TorchScript module to: {script_path}')
except Exception as e:
    print('TorchScript export failed:', e)
finally:
    resnet18.train(was_training)