## Instructions

To train the model and generate the `vision_model.pth` file, you need to run all the cells in this notebook. You can do this by clicking the "Run All" button in the toolbar above, or by pressing `Ctrl+Alt+Enter`.

# Train Multi-Label Computer Vision Model on Google Colab

This notebook trains a multi-label classifier on the Chest X-ray dataset. It is designed to be run in a Google Colab environment to leverage free GPU resources.

In [11]:
!pip install torch torchvision pandas azure-storage-blob scikit-learn



In [28]:
import os
import io
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from azure.storage.blob import BlobServiceClient

## 1. Enter Your Azure Credentials

Please enter your Azure Storage Account name and a SAS token with read/list permissions for the blob service.

In [29]:
STORAGE_ACCOUNT_NAME = "clinicaldatalake25"  # Your storage account name
SAS_TOKEN = "sv=2024-11-04&ss=bfqt&srt=sco&sp=rwdlacupiytfx&se=2025-11-19T06:23:06Z&st=2025-11-18T22:08:06Z&spr=https&sig=RDdiKrvqJzkeU4AsMSvUCs4pSJqfrj3Lbs%2F%2B56PBbWk%3D"  # Paste your Blob SAS Token here
IMAGE_CONTAINER_NAME = "images"
LABELS_CONTAINER_NAME = "labels"

In [None]:
if not SAS_TOKEN:
    raise ValueError("Please paste your SAS_TOKEN into the cell above.")

storage_account_url = f"https://{STORAGE_ACCOUNT_NAME}.blob.core.windows.net"
blob_service_client = BlobServiceClient(account_url=storage_account_url, credential=SAS_TOKEN)

# VALIDATION STEP
try:
    print(f"Testing connection to {STORAGE_ACCOUNT_NAME}...")
    # Try to list containers to verify Read permissions
    containers = list(blob_service_client.list_containers(results_per_page=1))
    print("✅ Connection successful! SAS Token is valid.")
    
    # Ensure the upload container exists NOW, so we don't fail later
    upload_container = "ml-models"
    try:
        blob_service_client.create_container(upload_container)
        print(f"✅ Container '{upload_container}' created.")
    except Exception:
        print(f"✅ Container '{upload_container}' already exists.")
        
except Exception as e:
    print(f"❌ CONNECTION FAILED: {e}")
    raise RuntimeError("Invalid SAS Token or Connection. Please fix credentials before proceeding.")

Testing connection to clinicaldatalake25...
✅ Connection successful! SAS Token is valid.
✅ Container 'ml-models' created.


In [31]:
# Download and load labels
try:
    labels_container_client = blob_service_client.get_container_client(LABELS_CONTAINER_NAME)
    blob_client = labels_container_client.get_blob_client("Data_Entry_2017.csv")
    downloader = blob_client.download_blob()
    labels_df = pd.read_csv(io.BytesIO(downloader.readall()))
    print("Labels loaded successfully.")
except Exception as e:
    print(f"Error loading labels: {e}")
    raise

Labels loaded successfully.


In [32]:
# Get all unique labels
all_labels_str = '|'.join(labels_df['Finding Labels'].dropna())
all_labels = sorted(list(set(all_labels_str.split('|'))))
label_to_int = {label: i for i, label in enumerate(all_labels)}
num_classes = len(all_labels)
print(f"Found {num_classes} unique classes.")

Found 15 unique classes.


In [33]:
# Get list of images from blob storage
print("Listing images in blob storage...")
image_container_client = blob_service_client.get_container_client(IMAGE_CONTAINER_NAME)
blob_list = [blob.name for blob in image_container_client.list_blobs()]
images_df = pd.DataFrame(blob_list, columns=['blob_path'])
images_df['Image Index'] = images_df['blob_path'].apply(lambda x: os.path.basename(x))
print(f"Found {len(images_df)} images in blob storage.")

Listing images in blob storage...
Found 30000 images in blob storage.


In [34]:
# Filter the labels to only include images that are actually in our blob storage
clean_labels_df = labels_df[labels_df['Image Index'].isin(images_df['Image Index'])].copy()
clean_labels_df.reset_index(drop=True, inplace=True)

In [None]:
# Merge to get the full blob path for each labeled image
df = pd.merge(images_df, clean_labels_df, on='Image Index', how='inner')

# SAMPLING STEP
# Limit the dataset size to ensure Colab can handle the training
SAMPLE_SIZE = 2000
if len(df) > SAMPLE_SIZE:
    print(f"Dataset size ({len(df)}) is larger than {SAMPLE_SIZE}. Sampling {SAMPLE_SIZE} images...")
    df = df.sample(n=SAMPLE_SIZE, random_state=42).reset_index(drop=True)
else:
    print(f"Dataset size ({len(df)}) is within limits. Using full dataset.")

print(f"Final dataset size: {len(df)}")

Dataset size (30000) is larger than 2000. Sampling 2000 images...
Final dataset size: 2000


In [36]:
# Split data
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

In [37]:
class ChestXrayDataset(Dataset):
    def __init__(self, df, container_client, label_map, transform=None):
        self.df = df
        self.container_client = container_client
        self.label_map = label_map
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['blob_path']
        
        blob_client = self.container_client.get_blob_client(img_path)
        downloader = blob_client.download_blob()
        image_bytes = downloader.readall()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Create the multi-hot encoded label tensor
        labels = row['Finding Labels'].split('|')
        label_tensor = torch.zeros(len(self.label_map), dtype=torch.float32)
        for label in labels:
            if label in self.label_map:
                label_tensor[self.label_map[label]] = 1.0
        
        return image, label_tensor

In [38]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

train_dataset = ChestXrayDataset(train_df, image_container_client, label_to_int, transform=data_transforms['train'])
val_dataset = ChestXrayDataset(val_df, image_container_client, label_to_int, transform=data_transforms['val'])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

In [39]:
# Use ResNet50 as per the project architecture requirements
print("Loading ResNet50 model...")
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

Loading ResNet50 model...
Using device: cuda:0


In [40]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [41]:
print("Starting model training...")
num_epochs = 3

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_dataloader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}')

    epoch_loss = running_loss / len(train_dataloader)
    print(f'--- End of Epoch [{epoch+1}/{num_epochs}], Average Training Loss: {epoch_loss:.4f} ---')

print('Finished Training')

Starting model training...
Epoch [1/3], Batch [10/50], Loss: 0.1586
Epoch [1/3], Batch [20/50], Loss: 0.2130
Epoch [1/3], Batch [30/50], Loss: 0.2848
Epoch [1/3], Batch [40/50], Loss: 0.1666
Epoch [1/3], Batch [50/50], Loss: 0.1858
--- End of Epoch [1/3], Average Training Loss: 0.2390 ---
Epoch [2/3], Batch [10/50], Loss: 0.1793
Epoch [2/3], Batch [20/50], Loss: 0.2537
Epoch [2/3], Batch [30/50], Loss: 0.2150
Epoch [2/3], Batch [40/50], Loss: 0.2361
Epoch [2/3], Batch [50/50], Loss: 0.2063
--- End of Epoch [2/3], Average Training Loss: 0.2063 ---
Epoch [3/3], Batch [10/50], Loss: 0.1888
Epoch [3/3], Batch [20/50], Loss: 0.1687
Epoch [3/3], Batch [30/50], Loss: 0.1795
Epoch [3/3], Batch [40/50], Loss: 0.1926
Epoch [3/3], Batch [50/50], Loss: 0.1593
--- End of Epoch [3/3], Average Training Loss: 0.2006 ---
Finished Training


In [42]:
torch.save(model.state_dict(), 'vision_model.pth')
print("Model saved to vision_model.pth in the same directory as this notebook.")

Model saved to vision_model.pth in the same directory as this notebook.


## 2. Model Saved Locally

The trained model has been saved as `vision_model.pth` in the same directory as this notebook. You can find it in the file explorer.

In [43]:
# List files in the current directory to confirm the model is saved
!ls -l

total 92268
drwxr-xr-x 1 root root     4096 Nov 17 14:29 sample_data
-rw-r--r-- 1 root root 94474717 Nov 18 23:18 vision_model.pth


## 3. Upload Model to Azure Blob Storage

This step uploads the trained model back to Azure so you can use it in your API.

In [44]:
print("\nUploading model to Azure Blob Storage...")

MODEL_CONTAINER_NAME = "ml-models"
BLOB_NAME = "vision/vision_model.pth"
LOCAL_MODEL_PATH = "vision_model.pth"

try:
    # Create the container if it doesn't exist
    try:
        container_client = blob_service_client.create_container(MODEL_CONTAINER_NAME)
        print(f"Container '{MODEL_CONTAINER_NAME}' created.")
    except Exception:
        # Container likely already exists
        container_client = blob_service_client.get_container_client(MODEL_CONTAINER_NAME)
        print(f"Container '{MODEL_CONTAINER_NAME}' already exists (or access denied to create).")

    # Get blob client
    blob_client = container_client.get_blob_client(BLOB_NAME)

    # Upload
    with open(LOCAL_MODEL_PATH, "rb") as data:
        blob_client.upload_blob(data, overwrite=True)

    print(f"SUCCESS: Model uploaded to container '{MODEL_CONTAINER_NAME}' as '{BLOB_NAME}'")

except Exception as e:
    print(f"ERROR: Failed to upload model to Azure. Details: {e}")
    print("You may need to manually download 'vision_model.pth' from the file explorer on the left.")


Uploading model to Azure Blob Storage...
Container 'ml-models' already exists (or access denied to create).
SUCCESS: Model uploaded to container 'ml-models' as 'vision/vision_model.pth'
