In [1]:
import os 
import torch.distributed as dist
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

In [2]:
import torch.nn as nn
import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

torch.cuda.set_device(0)

class VolumeClassifier(nn.Module):
    def __init__(self, embed_dim=768, hidden_dim=256, num_classes=2):
        super().__init__()
        self.embedder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        # x: (B, 1, D, H, W)
        B, _, D, H, W = x.shape
        x = x.squeeze(1)  # (B, D, H, W)

        # Reshape to (B * D, 3, H, W)
        x = x.view(B * D, H, W)
        x = x.unsqueeze(1).repeat(1, 3, 1, 1)  # → (B*D, 3, H, W)
        x = self.embedder(x)  # → (B*D, embed_dim)
        x = x.view(B, D, -1)  # → (B, D, embed_dim)

        lstm_out, _ = self.lstm(x)  # → (B, D, hidden_dim)
        final_out = lstm_out[:, -1, :]  # use last LSTM output
        return self.classifier(final_out)
    

In [16]:
from datasets import getNoduleInfoList
nodules = getNoduleInfoList(['zara', 'sclc'])
from datasets import NoduleDataset
dataset = NoduleDataset(
    nodules, 
    isValSet_bool=True,
    dilate=15,
    resample=[224, 224, 224],
    box_size=[64, 64, 64],
    fixed_size=True,
)

from torch.utils.data import DataLoader
val_dl = DataLoader(
    dataset,
    batch_size=1,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
    shuffle=True,
)

2025-08-13 21:09:18,314 INFO     pid:3383353 datasets:244:__init__ <datasets.NoduleDataset object at 0x7fc2eb79bf40>: 499 validation samples


In [4]:
import torch.optim as optim
import torch.nn as nn

def train(rank, world_size):
    
    setup(rank, world_size)
    model = VolumeClassifier()
    model = FSDP(model, device_id=rank)
    # model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    num_epochs = 100

    for epoch in range(num_epochs):
        model.train()  # set model to training mode
        running_loss = 0.0

        for volumes, labels in val_dl:
            volumes = volumes.to(rank)
            labels = labels.to(rank)

            optimizer.zero_grad()

            outputs = model(volumes)  # forward pass
            loss = criterion(outputs, labels)  # compute loss

            loss.backward()  # backprop
            optimizer.step()  # update weights

            running_loss += loss.item() * volumes.size(0)

        epoch_loss = running_loss / len(val_dl.dataset)
        if rank == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    cleanup()



In [5]:
import torch.multiprocessing as mp 
mp.spawn(train,
             args=(4,),
             nprocs=4,
             join=True)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/data/kaplinsp/envs/lminfer/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/data/kaplinsp/envs/lminfer/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/data/kaplinsp/envs/lminfer/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/data/kaplinsp/envs/lminfer/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/data/kaplinsp/envs/lminfer/lib/python3.

ProcessExitedException: process 0 terminated with exit code 1

In [9]:
import matplotlib.pyplot as plt 
for images, labels in val_dl:
    plt.imshow(images[0][0][112], cmap="grey")
    break

2025-05-29 21:21:13,835 INFO     pid:1854998 datasets:116:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A462715.nrrd
2025-05-29 21:21:13,921 INFO     pid:1854999 datasets:116:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A60.nrrd
2025-05-29 21:21:14,029 INFO     pid:1855001 datasets:116:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A462701.nrrd
2025-05-29 21:21:14,030 INFO     pid:1855000 datasets:116:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A462704.nrrd
2025-05-29 21:21:15,547 INFO     pid:1854998 datasets:116:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/S462749.nrrd
2025-05-29 21:21:15,590 INFO     pid:1854999 datasets:116:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A462734.nrrd
2025-05-29 21:21:15,821 INFO     pid:1855001 datasets:116:get_fixed_size_nodule Slicin

KeyboardInterrupt: 

In [17]:
embeddings = []
labels = []

def preprocess_volume(volume):  # volume: [1, H, L, W]
    # Step 1: Squeeze channel
    volume = volume.squeeze(0)  # Now [H, L, W]
    
    # Step 2: Get center slice along depth (axis=0)
    center_idx = volume.shape[0] // 2
    center_slice = volume[center_idx]  # Shape: [L, W] → becomes [H, W]
    
    # Step 3: Convert to 3-channel image
    slice_3ch = center_slice.unsqueeze(0).repeat(3, 1, 1)  # Shape: [3, H, W]
    
    return slice_3ch

dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitb14.eval().cuda()
with torch.no_grad():
    for volumes, lbls in val_dl:  # volumes: [B, 1, H, L, W]
        batch_feats = []
        for vol in volumes:  # loop through batch
            slice_img = preprocess_volume(vol)  # [3, H, W]
            input_img = slice_img.unsqueeze(0).cuda()  # Add batch dim
            
            feat = dinov2_vitb14(input_img)  # [1, D]
            batch_feats.append(feat.cpu())
        
        batch_feats = torch.cat(batch_feats, dim=0)  # [B, D]
        embeddings.append(batch_feats)
        labels.append(lbls)

embeddings = torch.cat(embeddings)
labels = torch.cat(labels)

Using cache found in /home/kaplinsp/.cache/torch/hub/facebookresearch_dinov2_main
2025-08-13 21:09:35,244 INFO     pid:4016336 datasets:096:get_fixed_size_nodule Slicing nodule from image for /data/shastra1/Data_zara/NSCLC_Zara/NSCLC_092.nrrd
2025-08-13 21:09:35,244 INFO     pid:4016335 datasets:096:get_fixed_size_nodule Slicing nodule from image for /data/shastra1/Data_zara/SC_Zara/SC_017 .nrrd
2025-08-13 21:09:35,249 INFO     pid:4016338 datasets:096:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A462726.nrrd
2025-08-13 21:09:35,405 INFO     pid:4016337 datasets:096:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A462730.nrrd
2025-08-13 21:09:36,630 INFO     pid:4016335 datasets:096:get_fixed_size_nodule Slicing nodule from image for /data/kaplinsp/transformation/A462734.nrrd
2025-08-13 21:09:36,791 INFO     pid:4016338 datasets:096:get_fixed_size_nodule Slicing nodule from image for /data/shastra1/Data_zara/NSCLC_Z

In [18]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    roc_auc_score, precision_score, recall_score, f1_score, classification_report, accuracy_score
)
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(embeddings, labels, test_size=0.2, random_state=42)

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# Evaluate
y_pred = clf.predict(X_test)
# Precision, Recall, F1
precision = precision_score(y_test, y_pred, average='macro')
recall = recall_score(y_test, y_pred, average='macro')
f1 = f1_score(y_test, y_pred, average='macro')
roc_auc = roc_auc_score(y_test, y_pred)

# Report
print(f"Test Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(f"ROC AUC:       {roc_auc:.4f}")
print(f"Precision:     {precision:.4f}")
print(f"Recall:        {recall:.4f}")
print(f"F1 Score:      {f1:.4f}")

# Optional: full report per class
print("\nClassification Report:\n", classification_report(y_test, y_pred))


Test Accuracy: 0.5600
ROC AUC:       0.5583
Precision:     0.5561
Recall:        0.5583
F1 Score:      0.5536

Classification Report:
               precision    recall  f1-score   support

           0       0.65      0.57      0.61        60
           1       0.46      0.55      0.50        40

    accuracy                           0.56       100
   macro avg       0.56      0.56      0.55       100
weighted avg       0.58      0.56      0.56       100

