In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms
from PIL import Image
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from itertools import product


In [42]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Root directory handed to CSVDataset below; the CSV's 'Path' values are joined onto it.
image_root = '/central/groups/CS156b/2025/CodeMonkeys/input_images'
# Label column that is filtered, rescaled and trained on.
target_col = 'Lung Opacity'
# NOTE(review): no module-level code visible in this notebook reads this name
# (the datasets are built with image_root) — confirm whether it is still needed.
image_root_dir = "input_images/train"
# Output directory for checkpoints, confusion-matrix PNGs and the grid-search CSV.
model_save_dir = 'maya_models/grid_search'
os.makedirs(model_save_dir, exist_ok=True)

# Passed as `num` to get_filtered_df: only the first end_df CSV rows are considered.
end_df = 20000
# Loss weight given to samples whose (rescaled) label is 0.5; 0 and 1 get weight 1.0.
uncertain_weight_factor = 0.25
# Outputs below neg_cutoff are bucketed as 0.0, below pos_cutoff as 0.5, otherwise 1.0.
neg_cutoff = 0.25
pos_cutoff = 0.75

num_epochs = 5  # Shorter for grid search
# NOTE(review): train_one_model takes its own freeze_until argument, which shadows
# this value; it appears unused by the grid search — confirm.
freeze_until = 4
# Batch size for both the training and validation DataLoader.
batch_size = 16

In [43]:
class CSVDataset(Dataset):
    """Dataset that serves (image tensor, label tensor) pairs described by a DataFrame.

    Each row must provide a 'Path' column (joined onto ``image_root_dir``) and the
    columns named in ``target_columns``. When ``save_dir`` is set, rows must also
    provide an 'Unnamed: 0' column, which names the cached tensor file
    ``<save_dir>/<Unnamed: 0>.pt``.

    Args:
        dataframe: Rows to serve; indexed positionally with ``iloc``.
        image_root_dir: Directory the 'Path' values are relative to.
        target_columns: List of label column names to return for each row.
        transform: Optional callable applied to the grayscale PIL image; when
            omitted, ``transforms.ToTensor()`` is used.
        save_dir: Optional directory for cached tensors; created if missing.
            With ``None`` no cache is read or written.
        use_saved_images: When true, an existing cached tensor is loaded instead
            of decoding the original image.

    NOTE(review): cached tensors are returned as stored, so they do not reflect
    later changes to ``transform`` — clear ``save_dir`` when the transform changes.
    """

    def __init__(self, dataframe, image_root_dir, target_columns=None, transform=None, save_dir=None, use_saved_images=False):
        self.data = dataframe
        self.image_root_dir = image_root_dir
        self.target_columns = target_columns
        self.transform = transform
        self.save_dir = save_dir
        self.use_saved_images = use_saved_images
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, labels)`` for the row at position ``idx``.

        Labels are coerced to numbers; anything non-numeric or missing becomes 0.
        """
        row = self.data.iloc[idx]

        # Only build a cache path when caching is configured: os.path.join(None, ...)
        # raises TypeError, which made the default save_dir=None unusable.
        saved_image_path = None
        if self.save_dir:
            saved_image_path = os.path.join(self.save_dir, f"{row['Unnamed: 0']}.pt")

        if self.use_saved_images and saved_image_path and os.path.exists(saved_image_path):
            # torch.load unpickles the file: only point save_dir at trusted data.
            image_tensor = torch.load(saved_image_path)
        else:
            original_image_path = os.path.join(self.image_root_dir, row['Path'])
            # The context manager closes the underlying file handle; convert()
            # returns an independent in-memory copy that outlives it.
            with Image.open(original_image_path) as raw_image:
                image = raw_image.convert("L")
            image_tensor = self.transform(image) if self.transform else transforms.ToTensor()(image)
            if saved_image_path:
                torch.save(image_tensor, saved_image_path)

        labels = pd.to_numeric(row[self.target_columns], errors='coerce').fillna(0).astype(float).values
        return image_tensor, torch.tensor(labels, dtype=torch.float32)

In [44]:
class MultiLabelResNet50(nn.Module):
    """ResNet-50 (IMAGENET1K_V1 weights) whose ``fc`` layer is replaced by a small MLP head.

    The head is Linear -> ReLU -> Dropout -> Linear -> Sigmoid, so every one of the
    ``num_classes`` outputs lies in (0, 1).

    Args:
        num_classes: Number of output units.
        hidden_size: Width of the hidden layer in the head.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, num_classes, hidden_size, dropout):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        feature_dim = backbone.fc.in_features
        # Positional Sequential keeps the state_dict keys as base_model.fc.<index>.*
        head_layers = [
            nn.Linear(feature_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes),
            nn.Sigmoid(),
        ]
        backbone.fc = nn.Sequential(*head_layers)
        self.base_model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and the replacement head."""
        return self.base_model(x)


class MultiLabelDenseNet121(nn.Module):
    """DenseNet-121 (IMAGENET1K_V1 weights) whose ``classifier`` is replaced by a small MLP head.

    The head is Linear -> ReLU -> Dropout -> Linear -> Sigmoid, so every one of the
    ``num_classes`` outputs lies in (0, 1).

    Args:
        num_classes: Number of output units.
        hidden_size: Width of the hidden layer in the head.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, num_classes, hidden_size, dropout):
        super().__init__()
        backbone = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        feature_dim = backbone.classifier.in_features
        # Positional Sequential keeps the state_dict keys as base_model.classifier.<index>.*
        head_layers = [
            nn.Linear(feature_dim, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes),
            nn.Sigmoid(),
        ]
        backbone.classifier = nn.Sequential(*head_layers)
        self.base_model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and the replacement head."""
        return self.base_model(x)


def freeze_base_layers(model, until_layer=6):
    """Turn off gradients for the first ``until_layer`` children of ``model.base_model``.

    Children are counted in the order ``base_model.children()`` yields them; every
    parameter inside a frozen child gets ``requires_grad = False``. Later children
    are left untouched. The same ``model`` object is returned.
    """
    for position, block in enumerate(model.base_model.children()):
        if position >= until_layer:
            # Positions only increase, so nothing further can qualify.
            break
        for weight in block.parameters():
            weight.requires_grad = False
    return model

In [None]:
# Absolute location of train2023.csv; the same path is what get_filtered_df is called with below.
csv_path = '/central/groups/CS156b/2025/CodeMonkeys/train2023.csv'

def get_filtered_df(col, num=None, csv_path='train2023.csv'):
    """Load the CSV, keep rows that have a value in ``col`` and rescale that column.

    Args:
        col: Name of the label column to filter on and rescale.
        num: When given, only the first ``num`` rows of the CSV are considered
            (before dropping missing labels). ``None`` means the whole file.
        csv_path: Path of the CSV file to read.

    Returns:
        A new DataFrame without rows where ``col`` is missing, in which ``col``
        has been mapped through ``(value + 1) / 2`` (so -1 -> 0.0, 0 -> 0.5, 1 -> 1.0).
    """
    df = pd.read_csv(csv_path)
    # `is not None` rather than truthiness: num=0 must mean "no rows", not "all rows".
    if num is not None:
        df = df.iloc[:num]
    labelled = df.dropna(subset=[col]).copy()
    labelled[col] = (labelled[col] + 1) / 2
    return labelled


# Load the labelled rows, reusing the csv_path defined at the top of this cell
# instead of repeating the literal path.
filtered_df = get_filtered_df(target_col, num=end_df, csv_path=csv_path)
print(f"Total images with valid '{target_col}' label: {len(filtered_df)}")

target_columns = [target_col]

# Fixed random_state keeps the train/validation split identical across runs.
train_df, val_df = train_test_split(filtered_df, test_size=0.15, random_state=42)
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Both datasets share one tensor cache directory; files are keyed by the CSV's 'Unnamed: 0' value.
cache_dir = os.path.join(image_root, 'train')
train_dataset = CSVDataset(train_df, image_root, target_columns, transform, save_dir=cache_dir, use_saved_images=True)
val_dataset = CSVDataset(val_df, image_root, target_columns, transform, save_dir=cache_dir, use_saved_images=True)

# Inverse-frequency sampling weights so each label value (0, 0.5, 1) is drawn
# roughly equally often during training.
lo_labels = train_df[target_col].values
label_map = {0.0: 0, 0.5: 1, 1.0: 2}
mapped_labels = np.array([label_map[float(lbl)] for lbl in lo_labels], dtype=np.int64)
# minlength guarantees one count per class even when a class is absent from the split.
class_counts = np.bincount(mapped_labels, minlength=len(label_map))
weights = 1. / (class_counts + 1e-6)
sample_weights = torch.tensor(weights[mapped_labels], dtype=torch.float)
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Per-label loss weights: the uncertain label (0.5) contributes less than 0 or 1.
class_weights = {target_col: {0: 1.0, 0.5: uncertain_weight_factor, 1: 1.0}}

# Element-wise squared error; the masked loss functions below do the reduction.
criterion = nn.MSELoss(reduction='none')

def masked_MSE_loss(output, target, class_weights):
    """Weighted mean squared error that ignores NaN targets.

    Relies on the module-level ``criterion`` (an element-wise loss) and
    ``target_columns`` (column names, in the same order as the columns of ``target``).

    Args:
        output: Model predictions, indexed as ``[sample, column]``.
        target: Ground-truth values with the same shape; NaN marks "no label".
        class_weights: ``{column name: {label value: weight}}``; label values
            missing from the inner dict get weight 1.

    Returns:
        Scalar tensor: sum of weighted squared errors over labelled entries divided
        by the number of labelled entries (0 when nothing is labelled).
    """
    valid = ~torch.isnan(target)
    # NaN targets must be replaced *before* the loss is computed: NaN * 0 is
    # still NaN, so multiplying by the mask afterwards cannot remove them.
    safe_target = torch.where(valid, target, torch.zeros_like(target))
    loss = criterion(output, safe_target) * valid

    weights = torch.ones_like(loss)
    for class_idx, col in enumerate(target_columns):
        col_weights = class_weights[col]
        # One tolist() call per column instead of one .item() per element.
        weights[:, class_idx] = torch.tensor(
            [col_weights.get(value, 1) for value in target[:, class_idx].tolist()],
            dtype=torch.float32, device=output.device,
        )

    # clamp avoids 0/0 = NaN when a batch contains no labelled entries.
    return (loss * weights).sum() / valid.sum().clamp(min=1)

def masked_MSE_loss_2(output, target, class_weights):
    """Single-column variant of the weighted, NaN-masked mean squared error.

    Relies on the module-level ``criterion`` (an element-wise loss) and uses the
    weights of ``target_columns[0]`` for every entry.

    Args:
        output: Predictions of shape (batch_size, 1).
        target: Targets of shape (batch_size, 1); NaN marks "no label".
        class_weights: ``{column name: {label value: weight}}``; label values
            missing from the inner dict get weight 1.0.

    Returns:
        Scalar tensor: sum of weighted squared errors over labelled entries divided
        by the number of labelled entries (0 when nothing is labelled).
    """
    mask = ~torch.isnan(target)

    # Replace NaN targets before computing the loss: NaN * 0 is still NaN, so
    # masking after the fact would leave the whole batch loss NaN.
    safe_target = torch.where(mask, target, torch.zeros_like(target))
    loss = criterion(output, safe_target)  # element-wise, shape (batch_size, 1)

    # Look up one weight per label value (0.0, 0.5, 1.0) in a single host transfer.
    column_weights = class_weights[target_columns[0]]
    weights = torch.tensor(
        [column_weights.get(value, 1.0) for value in target.reshape(-1).tolist()],
        dtype=torch.float32, device=output.device
    ).view(-1, 1)  # same shape as loss

    weighted_loss = loss * weights * mask

    # clamp avoids 0/0 = NaN when a batch contains no labelled entries.
    return weighted_loss.sum() / mask.sum().clamp(min=1)


#--------------------------------------------TRAINING------------------------------------------------------#

def _save_confusion_matrix(all_preds, all_labels, title, filename):
    """Draw the Neg/Unc/Pos confusion matrix for one epoch and save it under model_save_dir."""
    # Kept function-local so the plotting stack is only needed when a matrix is drawn.
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt

    # Values other than 0.0 / 0.5 / 1.0 map to -1 and therefore fall outside the matrix.
    float_to_int = {0.0: 0, 0.5: 1, 1.0: 2}
    pred_classes = np.array([float_to_int.get(float(v), -1) for v in np.concatenate(all_preds).flatten()])
    true_classes = np.array([float_to_int.get(float(v), -1) for v in np.concatenate(all_labels).flatten()])

    cm = confusion_matrix(true_classes, pred_classes, labels=[0, 1, 2])
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=["Neg", "Unc", "Pos"],
                yticklabels=["Neg", "Unc", "Pos"])
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.savefig(os.path.join(model_save_dir, filename))
    plt.close()


def train_one_model(model_type, hidden_size, dropout, freeze_until, lr, weight_decay):
    """Train one configuration and return its best validation loss and matching accuracy.

    Uses the module-level ``device``, ``num_epochs``, ``train_loader``, ``val_loader``,
    ``class_weights``, ``neg_cutoff``, ``pos_cutoff`` and ``model_save_dir``.

    Args:
        model_type: "resnet50" selects MultiLabelResNet50; anything else selects
            MultiLabelDenseNet121.
        hidden_size: Hidden width of the classification head.
        dropout: Dropout probability of the classification head.
        freeze_until: Number of leading backbone children to freeze (ResNet only).
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        ``(best_val_loss, best_val_accuracy)`` — both taken from the epoch with the
        lowest validation loss. If no epoch improved on infinity (or no epoch ran)
        the result is ``(inf, nan)``.

    Side effects:
        Saves one confusion-matrix PNG per epoch and the best checkpoint into
        ``model_save_dir``, and prints one summary line per epoch.
    """
    if model_type == "resnet50":
        model = MultiLabelResNet50(num_classes=1, hidden_size=hidden_size, dropout=dropout).to(device)
        model = freeze_base_layers(model, until_layer=freeze_until)
    else:
        # NOTE(review): freeze_until is not applied to DenseNet, so every DenseNet
        # run in a grid over freeze_until trains the same architecture — confirm intent.
        model = MultiLabelDenseNet121(num_classes=1, hidden_size=hidden_size, dropout=dropout).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Every hyperparameter is part of the tag so that configurations differing only
    # in dropout or freeze_until no longer overwrite each other's best checkpoint.
    run_tag = f"{model_type}_h{hidden_size}_d{dropout}_frz{freeze_until}_lr{lr}_wd{weight_decay}"
    patience = 2  # epochs without improvement before stopping early

    best_val_loss = float('inf')
    best_val_accuracy = float('nan')
    patience_counter = 0

    for epoch in range(num_epochs):
        # ---- TRAINING ----
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)

            loss = masked_MSE_loss(outputs, labels, class_weights)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)

        # ---- VALIDATION ----
        model.eval()
        val_loss, val_loss_2, val_correct, val_total = 0.0, 0.0, 0, 0
        all_preds, all_labels = [], []

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)

                val_loss += masked_MSE_loss(outputs, labels, class_weights).item()
                val_loss_2 += masked_MSE_loss_2(outputs, labels, class_weights).item()

                # Bucket the sigmoid output into the three label values.
                preds = torch.where(outputs < neg_cutoff, 0.0,
                        torch.where(outputs < pos_cutoff, 0.5, 1.0))

                val_correct += (preds == labels).sum().item()
                val_total += labels.numel()

                all_preds.append(preds.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        avg_val_loss_2 = val_loss_2 / len(val_loader)
        val_accuracy = val_correct / val_total

        print(f"{model_type} H{hidden_size} D{dropout} F{freeze_until} LR{lr} WD{weight_decay} | Epoch {epoch+1}: "
              f"TrainLoss={avg_train_loss:.4f}, ValLoss={avg_val_loss:.4f}, "
              f"ValLoss2={avg_val_loss_2:.4f}, ValAcc={val_accuracy:.4f}")

        # ---- CONFUSION MATRIX ----
        _save_confusion_matrix(
            all_preds, all_labels,
            title=f'CM: {model_type}, H={hidden_size}, D={dropout}, Epoch {epoch+1}',
            filename=f"cm_{run_tag}_epoch{epoch+1}.png",
        )

        # ---- EARLY STOPPING ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Keep the accuracy of the same epoch so the returned pair is consistent.
            best_val_accuracy = val_accuracy
            patience_counter = 0
            torch.save(model.state_dict(),
                       os.path.join(model_save_dir, f"{run_tag}_best.pth"))
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⏹️ Early stopping triggered.")
                break

    return best_val_loss, best_val_accuracy



Total images with valid 'Lung Opacity' label: 10147


In [46]:
from itertools import product

def grid_search():
    """Train every hyperparameter combination and record the results.

    Iterates the Cartesian product of model type, hidden size, learning rate,
    weight decay, frozen-layer count and dropout, calling ``train_one_model`` for
    each combination. One row per finished configuration is appended to
    ``<model_save_dir>/grid_results.csv`` as soon as that configuration completes,
    so an interrupted or crashed run keeps everything finished so far. The file is
    started fresh on every call.

    Prints the configuration being trained and, at the end, the configuration with
    the lowest validation loss. Returns ``None``.
    """
    model_types = ['resnet50', 'densenet121']
    hidden_sizes = [256, 384, 512]
    lrs = [1e-4, 5e-4]
    wds = [1e-5, 1e-4]
    frozen_layers = [3, 4, 5]
    dropouts = [0.2, 0.4, 0.6]

    columns = ["Model", "Hidden", "Dropout", "FreezeUntil", "LR", "WD", "ValLoss", "ValAcc"]
    results_path = os.path.join(model_save_dir, "grid_results.csv")
    # Write just the header first; each finished configuration is appended below.
    pd.DataFrame(columns=columns).to_csv(results_path, index=False)

    best_loss = float('inf')
    best_config = None

    # product() varies the last iterable fastest: dropout, then frozen layers, ...
    for m, h, lr, wd, fl, ds in product(model_types, hidden_sizes, lrs, wds, frozen_layers, dropouts):
        print(f"\n🔍 Training config: model={m}, hidden={h}, dropout={ds}, freeze_until={fl}, lr={lr}, wd={wd}")
        val_loss, val_acc = train_one_model(m, h, ds, fl, lr, wd)

        pd.DataFrame([(m, h, ds, fl, lr, wd, val_loss, val_acc)], columns=columns).to_csv(
            results_path, mode='a', header=False, index=False
        )

        if val_loss < best_loss:
            best_loss = val_loss
            best_config = (m, h, ds, fl, lr, wd)

    print(f"\n✅ Best Config: {best_config} with ValLoss {best_loss:.4f}")


In [None]:
grid_search()



🔍 Training config: model=resnet50, hidden=256, dropout=0.2, freeze_until=3, lr=0.0001, wd=1e-05
resnet50 H256 D0.2 F3 LR0.0001 WD1e-05 | Epoch 1: ValLoss=0.0809, ValLoss2=0.0809, ValAcc=0.7912
resnet50 H256 D0.2 F3 LR0.0001 WD1e-05 | Epoch 2: ValLoss=0.0679, ValLoss2=0.0679, ValAcc=0.7873
resnet50 H256 D0.2 F3 LR0.0001 WD1e-05 | Epoch 3: ValLoss=0.0693, ValLoss2=0.0693, ValAcc=0.8024
resnet50 H256 D0.2 F3 LR0.0001 WD1e-05 | Epoch 4: ValLoss=0.0682, ValLoss2=0.0682, ValAcc=0.7991
⏹️ Early stopping triggered.

🔍 Training config: model=resnet50, hidden=256, dropout=0.4, freeze_until=3, lr=0.0001, wd=1e-05
resnet50 H256 D0.4 F3 LR0.0001 WD1e-05 | Epoch 1: ValLoss=0.1020, ValLoss2=0.1020, ValAcc=0.7032
resnet50 H256 D0.4 F3 LR0.0001 WD1e-05 | Epoch 2: ValLoss=0.0704, ValLoss2=0.0704, ValAcc=0.7965
resnet50 H256 D0.4 F3 LR0.0001 WD1e-05 | Epoch 3: ValLoss=0.0868, ValLoss2=0.0868, ValAcc=0.7505
resnet50 H256 D0.4 F3 LR0.0001 WD1e-05 | Epoch 4: ValLoss=0.0777, ValLoss2=0.0777, ValAcc=0.8056
⏹