In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import os
from PIL import Image
from sklearn.metrics import roc_auc_score, classification_report, accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from transformers import ViTForImageClassification, ViTFeatureExtractor
import matplotlib.pyplot as plt

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub without embedding a token in the notebook.
# The token is read from the HF_TOKEN environment variable; when that variable is
# unset, `token=None` makes `login` fall back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))


In [None]:
import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from PIL import Image

# need to specify MLP layer and activation function for proper init
model = timm.create_model("hf-hub:paige-ai/Virchow2", pretrained=True, mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU)
model = model.eval()

# Preprocessing pipeline derived from the model's own pretrained config.
# NOTE: this rebinds the name `transforms`, shadowing the torchvision module
# imported under the same name in the first cell.
transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

# Convert to RGB, as the training Dataset classes in this notebook do, so that
# non-RGB files still produce a 3-channel input.
image = Image.open("/home/iambrink/NOH_Thyroid_Cancer_Data/TAN/001/IMG_20220623_134910.jpg").convert("RGB")
image = transforms(image).unsqueeze(0)  # size: 1 x 3 x 224 x 224

# Feature extraction only: disable autograd so no graph is built for the forward pass.
with torch.inference_mode():
    output = model(image)  # size: 1 x 261 x 1280

class_token = output[:, 0]    # size: 1 x 1280
patch_tokens = output[:, 5:]  # size: 1 x 256 x 1280, tokens 1-4 are register tokens so we ignore those

# concatenate class token and average pool of patch tokens
embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)  # size: 1 x 2560


In [None]:
# Repeat the forward pass on the GPU: move the model and the input there first.
model, image = model.to("cuda"), image.to("cuda")

# No autograd, and fp16 autocast for the CUDA forward pass.
with torch.inference_mode():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(image)

# Token 0 is the class token; tokens from index 5 onward are the patch tokens.
class_token, patch_tokens = output[:, 0], output[:, 5:]

# Embedding = class token concatenated with the mean over the patch tokens.
embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)

# the model output will be fp32 because the final operation is a LayerNorm that is run in mixed precision
# optionally, the embedding can be converted to fp16 for efficiency in downstream use
embedding = embedding.to(torch.float16)


In [None]:
# Build the Virchow2 architecture with the MLP layer and activation function it
# needs for proper init (same arguments as the other create_model calls in this
# notebook). `pretrained` is not passed, so timm's default applies.
model = timm.create_model("hf-hub:paige-ai/Virchow2", mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU)


# training

Semi-faster variant (originally described as batch size 32; the code cell below uses `batch_size=64`).

In [None]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Turn on cuDNN benchmark mode
torch.backends.cudnn.benchmark = True

# Locations of the label CSV and of the image root directory
data_path = "/home/iambrink/NOH_Thyroid_Cancer_Data/CSV-files/Thyroid_Cancer_TAN&NOH_file.csv"
base_image_path = "/home/iambrink/NOH_Thyroid_Cancer_Data/superdata/"

# Read the CSV and keep only the rows that carry a surgery diagnosis label
df = pd.read_csv(data_path).dropna(subset=["Surgery diagnosis in number"])

# Stratified 80/20 train/validation split with a fixed seed
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df["Surgery diagnosis in number"],
)

# Custom Dataset class
class CustomDataset(Dataset):
    """Image/label dataset backed by a DataFrame.

    Each row provides a relative image path in the ``image_path`` column and a
    label in the ``Surgery diagnosis in number`` column.

    Args:
        dataframe: Table with one row per sample; stored as given.
        base_path: Directory that the relative image paths are joined onto.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, dataframe, base_path, transform=None):
        self.transform = transform
        self.base_path = base_path
        self.dataframe = dataframe

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is opened and converted to RGB, then passed through
        ``transform`` when one is set; the label is returned as a ``torch.long``
        scalar tensor.
        """
        record = self.dataframe.iloc[idx]
        # Paths in the CSV may use backslashes; normalise them before joining.
        relative_path = record["image_path"].replace("\\", "/")
        picture = Image.open(os.path.join(self.base_path, relative_path)).convert("RGB")
        target = int(record["Surgery diagnosis in number"])
        if self.transform:
            picture = self.transform(picture)
        return picture, torch.tensor(target, dtype=torch.long)

# Load the pre-trained model
model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU
)

# Modify the model for binary classification: replace the head with a 2-way linear layer
num_features = model.num_features
model.head = nn.Linear(num_features, 2)

# Move model to GPU (falls back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define transformations using timm's data config
config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**config)

# Create datasets (the same transform is used for training and validation)
train_dataset = CustomDataset(train_df, base_image_path, transform=transform)
val_dataset = CustomDataset(val_df, base_image_path, transform=transform)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=16, pin_memory=True, persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=16, pin_memory=True, persistent_workers=True)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)

# Learning rate scheduler; mode='max' because it is stepped on validation accuracy
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4, verbose=True)

# Automatic mixed precision.
# `torch.cuda.amp.GradScaler()` is deprecated in favour of `torch.amp.GradScaler(device, ...)`;
# gradient scaling is only enabled when actually running on CUDA.
scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda")

# Training setup
num_epochs = 25
best_val_acc = 0.0  # Store the best validation accuracy

# Single source of truth for the checkpoint path, so the save call and the
# final message always name the same file.
best_model_path = "virchow2_best_test3.pth"
# fp16 autocast is only enabled when running on CUDA.
use_amp = device.type == "cuda"
# Upper bound on the global gradient norm, applied after unscaling the gradients.
max_grad_norm = 1.0

for epoch in range(num_epochs):
    model.train()
    # train_loss accumulates per-sample loss so the epoch average is not skewed
    # by a smaller final batch; skipped counts batches with a non-finite loss.
    train_loss, correct, total, skipped = 0.0, 0, 0, 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            outputs = model(images)  # Expected shape: [batch_size, 261, 2]
            outputs = outputs[:, 0, :]  # Use the class token [batch_size, 2]
            loss = criterion(outputs, labels)

        # A non-finite loss would turn the running average into NaN and produce
        # unusable gradients; skip the batch and report it instead.
        if not torch.isfinite(loss):
            skipped += 1
            progress_bar.set_postfix(skipped=skipped)
            continue

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        batch_size = labels.size(0)
        train_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += batch_size
        progress_bar.set_postfix(loss=loss.item(), acc=100. * correct / total, skipped=skipped)

    # max(total, 1) keeps the report well-defined even if every batch was skipped.
    train_acc = 100. * correct / max(total, 1)
    print(f"Epoch {epoch+1}: Train Loss={train_loss/max(total, 1):.4f}, Train Accuracy={train_acc:.2f}%")
    if skipped:
        print(f"Epoch {epoch+1}: skipped {skipped} training batch(es) with non-finite loss")

    # Validation phase (full precision, no gradient tracking)
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)[:, 0, :]
            loss = criterion(outputs, labels)

            val_loss += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            val_correct += predicted.eq(labels).sum().item()
            val_total += labels.size(0)

    val_acc = 100. * val_correct / val_total
    print(f"Epoch {epoch+1}: Validation Loss={val_loss/val_total:.4f}, Validation Accuracy={val_acc:.2f}%")

    # Step scheduler based on validation accuracy
    scheduler.step(val_acc)

    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved with accuracy: {best_val_acc:.2f}%")

print(f"Training Complete. Best model saved as '{best_model_path}'.")


  from .autonotebook import tqdm as notebook_tqdm
  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Epoch 1/25: 100%|██████████| 41/41 [1:06:19<00:00, 97.07s/it, acc=47.5, loss=1.24] 

Epoch 1: Train Loss=1.9950, Train Accuracy=47.46%





Epoch 1: Validation Loss=0.7286, Validation Accuracy=52.62%
New best model saved with accuracy: 52.62%


Epoch 2/25: 100%|██████████| 41/41 [1:04:50<00:00, 94.88s/it, acc=51, loss=0.707]  


Epoch 2: Train Loss=0.7347, Train Accuracy=51.00%
Epoch 2: Validation Loss=0.7040, Validation Accuracy=52.62%


Epoch 3/25: 100%|██████████| 41/41 [1:06:48<00:00, 97.76s/it, acc=51.9, loss=0.776]


Epoch 3: Train Loss=0.7181, Train Accuracy=51.93%
Epoch 3: Validation Loss=0.7215, Validation Accuracy=52.62%


Epoch 4/25: 100%|██████████| 41/41 [1:06:41<00:00, 97.61s/it, acc=50.1, loss=0.811]


Epoch 4: Train Loss=0.7299, Train Accuracy=50.08%
Epoch 4: Validation Loss=0.7711, Validation Accuracy=47.38%


Epoch 5/25: 100%|██████████| 41/41 [1:06:44<00:00, 97.68s/it, acc=49.9, loss=0.682]


Epoch 5: Train Loss=0.7093, Train Accuracy=49.92%
Epoch 5: Validation Loss=0.7075, Validation Accuracy=47.38%


Epoch 6/25: 100%|██████████| 41/41 [1:06:39<00:00, 97.55s/it, acc=50.5, loss=0.735]


Epoch 6: Train Loss=0.7104, Train Accuracy=50.54%
Epoch 6: Validation Loss=0.7013, Validation Accuracy=52.62%


Epoch 7/25: 100%|██████████| 41/41 [1:06:39<00:00, 97.54s/it, acc=50.4, loss=0.688]


Epoch 7: Train Loss=0.7016, Train Accuracy=50.39%
Epoch 7: Validation Loss=0.6927, Validation Accuracy=52.62%


Epoch 8/25: 100%|██████████| 41/41 [1:06:43<00:00, 97.64s/it, acc=51.2, loss=0.704]


Epoch 8: Train Loss=0.6977, Train Accuracy=51.16%
Epoch 8: Validation Loss=0.6910, Validation Accuracy=52.62%


Epoch 9/25: 100%|██████████| 41/41 [1:06:45<00:00, 97.69s/it, acc=50.5, loss=0.71] 


Epoch 9: Train Loss=0.6988, Train Accuracy=50.54%
Epoch 9: Validation Loss=0.7021, Validation Accuracy=52.62%


Epoch 10/25: 100%|██████████| 41/41 [1:06:39<00:00, 97.55s/it, acc=50.5, loss=0.76] 


Epoch 10: Train Loss=0.7305, Train Accuracy=50.54%
Epoch 10: Validation Loss=0.7070, Validation Accuracy=47.38%


Epoch 11/25: 100%|██████████| 41/41 [1:05:37<00:00, 96.03s/it, acc=48.7, loss=nan] 


Epoch 11: Train Loss=nan, Train Accuracy=48.69%
Epoch 11: Validation Loss=0.7189, Validation Accuracy=52.62%


Epoch 12/25: 100%|██████████| 41/41 [1:05:33<00:00, 95.94s/it, acc=49.3, loss=nan]


Epoch 12: Train Loss=nan, Train Accuracy=49.27%
Epoch 12: Validation Loss=0.7189, Validation Accuracy=52.62%


Epoch 13/25: 100%|██████████| 41/41 [1:05:43<00:00, 96.18s/it, acc=48.5, loss=nan]  


Epoch 13: Train Loss=nan, Train Accuracy=48.54%
Epoch 13: Validation Loss=0.7183, Validation Accuracy=52.62%


Epoch 14/25: 100%|██████████| 41/41 [1:05:24<00:00, 95.72s/it, acc=47.3, loss=nan]


Epoch 14: Train Loss=nan, Train Accuracy=47.27%
Epoch 14: Validation Loss=0.7183, Validation Accuracy=52.62%


Epoch 15/25: 100%|██████████| 41/41 [1:05:26<00:00, 95.76s/it, acc=47.3, loss=nan]


Epoch 15: Train Loss=nan, Train Accuracy=47.27%
Epoch 15: Validation Loss=0.7183, Validation Accuracy=52.62%


Epoch 16/25: 100%|██████████| 41/41 [1:05:33<00:00, 95.95s/it, acc=47.3, loss=nan]


Epoch 16: Train Loss=nan, Train Accuracy=47.27%
Epoch 16: Validation Loss=0.7183, Validation Accuracy=52.62%


Epoch 17/25:  90%|█████████ | 37/41 [59:32<06:26, 96.71s/it, acc=47.5, loss=nan] 

# K-fold cross-validation

In [None]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold

# ————————————————————————————————————————————————————————————————
# 1. Enable cuDNN benchmark for potential speedups on fixed-size inputs
# ————————————————————————————————————————————————————————————————
torch.backends.cudnn.benchmark = True

# ————————————————————————————————————————————————————————————————
# 2. Data locations and training hyperparameters
# ————————————————————————————————————————————————————————————————
BASE_IMG_DIR = "/home/iambrink/NOH_Thyroid_Cancer_Data/superdata/"
DATA_CSV = "/home/iambrink/NOH_Thyroid_Cancer_Data/CSV-files/Thyroid_Cancer_TAN&NOH_file.csv"

NUM_FOLDS = 5
NUM_EPOCHS = 25
BATCH_SIZE = 64
NUM_WORKERS = 16
LR = 1e-5
WEIGHT_DECAY = 1e-5

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ————————————————————————————————————————————————————————————————
# 3. Read the CSV, keep labelled rows, and extract integer labels
# ————————————————————————————————————————————————————————————————
df = pd.read_csv(DATA_CSV).dropna(subset=["Surgery diagnosis in number"])
# Integer label array; passed to the stratified splitter further down.
labels = df["Surgery diagnosis in number"].astype(int).to_numpy()

# ————————————————————————————————————————————————————————————————
# 4. Custom Dataset definition
# ————————————————————————————————————————————————————————————————
class CustomDataset(Dataset):
    """Image/label dataset backed by a DataFrame with a fresh 0..n-1 index.

    Each row provides a relative image path in the ``image_path`` column and a
    label in the ``Surgery diagnosis in number`` column.

    Args:
        dataframe: Table with one row per sample; a re-indexed copy is stored,
            the caller's frame is left untouched.
        base_path: Directory that the relative image paths are joined onto.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, dataframe, base_path, transform=None):
        self.transform = transform
        self.base_path = base_path
        # Discard the original index so positions run from 0 to len - 1.
        self.dataframe = dataframe.reset_index(drop=True)

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is opened and converted to RGB, then passed through
        ``transform`` when one is set; the label is returned as a ``torch.long``
        scalar tensor.
        """
        record = self.dataframe.iloc[idx]
        # Paths in the CSV may use backslashes; normalise them before joining.
        relative_path = record["image_path"].replace("\\", "/")
        picture = Image.open(os.path.join(self.base_path, relative_path)).convert("RGB")
        target = int(record["Surgery diagnosis in number"])
        if self.transform:
            picture = self.transform(picture)
        return picture, torch.tensor(target, dtype=torch.long)

# ————————————————————————————————————————————————————————————————
# 5. Prepare image transforms once (use a dummy model for config)
# ————————————————————————————————————————————————————————————————
# Only the data config is needed from this throwaway model, so skip loading the
# pretrained weights: for an `hf-hub:` model timm still attaches `pretrained_cfg`
# from the hub config when `pretrained=False`.
_dummy_model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=False,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU
)
data_config = resolve_data_config(_dummy_model.pretrained_cfg, model=_dummy_model)
transform   = create_transform(**data_config)
# The model itself is not used again; release it right away.
del _dummy_model

# ————————————————————————————————————————————————————————————————
# 6. Set up Stratified K‑Fold
# ————————————————————————————————————————————————————————————————
# Shuffled, seeded stratified splitter with NUM_FOLDS folds.
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)

# One entry per fold is appended here: that fold's best validation accuracy.
fold_accuracies = list()

# ————————————————————————————————————————————————————————————————
# 7. Cross‑validation loop
# ————————————————————————————————————————————————————————————————
# fp16 autocast and gradient scaling are only enabled when running on CUDA.
USE_AMP = DEVICE.type == "cuda"
# Upper bound on the global gradient norm, applied after unscaling the gradients.
MAX_GRAD_NORM = 1.0

for fold, (train_idx, val_idx) in enumerate(skf.split(df, labels), start=1):
    print(f"\n\n########## Fold {fold}/{NUM_FOLDS} ##########")

    # Drop references to the previous fold's model and optimizer state BEFORE the
    # new model is moved to the device; otherwise both models (plus the old AdamW
    # state) are alive at the same time. The last fold's objects stay bound after
    # the loop finishes.
    model = optimizer = scheduler = scaler = None
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()

    # a) Split DataFrame into train / validation for this fold
    # NOTE(review): the split is per CSV row. If several rows belong to the same
    # patient, that patient can land in both train and validation — confirm
    # against the CSV schema whether a grouped split is needed.
    train_df = df.iloc[train_idx]
    val_df   = df.iloc[val_idx]

    # b) Create Datasets and DataLoaders
    train_ds = CustomDataset(train_df, BASE_IMG_DIR, transform=transform)
    val_ds   = CustomDataset(val_df,   BASE_IMG_DIR, transform=transform)
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=True
    )

    # c) Initialize a fresh model for this fold, with a 2-way linear head
    model = timm.create_model(
        "hf-hub:paige-ai/Virchow2",
        pretrained=True,
        mlp_layer=SwiGLUPacked,
        act_layer=torch.nn.SiLU
    )
    model.head = nn.Linear(model.num_features, 2)
    model = model.to(DEVICE)

    # d) Loss, optimizer, scheduler, AMP scaler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # mode='max' because the scheduler is stepped on validation accuracy.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=4,
        verbose=True
    )
    # `torch.cuda.amp.GradScaler()` is deprecated in favour of `torch.amp.GradScaler(device, ...)`.
    scaler    = torch.amp.GradScaler(DEVICE.type, enabled=USE_AMP)

    best_val_acc = 0.0

    # e) Epoch loop
    for epoch in range(1, NUM_EPOCHS + 1):
        # — Training —
        model.train()
        running_loss, correct, total, skipped = 0.0, 0, 0, 0
        pbar = tqdm(train_loader, desc=f"Fold {fold} Ep {epoch}/{NUM_EPOCHS}", leave=False)

        # The batch labels are named `targets` so they do not overwrite the
        # module-level `labels` array that was passed to `skf.split`.
        for images, targets in pbar:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()

            with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
                outputs = model(images)[:, 0, :]    # use class token
                loss    = criterion(outputs, targets)

            # A non-finite loss would turn the running average into NaN and
            # produce unusable gradients; skip the batch and count it.
            if not torch.isfinite(loss):
                skipped += 1
                pbar.set_postfix(skipped=skipped)
                continue

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * targets.size(0)
            _, preds     = outputs.max(1)
            correct     += preds.eq(targets).sum().item()
            total       += targets.size(0)

            pbar.set_postfix(
                loss=loss.item(),
                acc=100. * correct / total,
                skipped=skipped
            )

        # max(total, 1) keeps the report well-defined even if every batch was skipped.
        train_acc = 100. * correct / max(total, 1)
        train_loss = running_loss / max(total, 1)

        # — Validation (full precision, no gradient tracking) —
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = images.to(DEVICE), targets.to(DEVICE)
                outputs = model(images)[:, 0, :]
                loss    = criterion(outputs, targets)

                val_loss    += loss.item() * targets.size(0)
                _, preds     = outputs.max(1)
                val_correct += preds.eq(targets).sum().item()
                val_total   += targets.size(0)

        val_acc  = 100. * val_correct / val_total
        val_loss = val_loss / val_total

        # Scheduler step on validation accuracy
        scheduler.step(val_acc)

        print(
            f"Fold {fold} | Epoch {epoch}  "
            f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%  |  "
            f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%"
        )
        if skipped:
            print(f"Fold {fold} | Epoch {epoch}  skipped {skipped} training batch(es) with non-finite loss")

        # Save best model for this fold
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(
                model.state_dict(),
                f"virchow2_best_fold{fold}.pth"
            )
            print(f"➜ New best for fold {fold}: {best_val_acc:.2f}% saved.")

    # Record the best validation accuracy reached in any epoch of this fold.
    fold_accuracies.append(best_val_acc)
    print(f"=== Fold {fold} finished. Best Val Acc: {best_val_acc:.2f}% ===")

# ————————————————————————————————————————————————————————————————
# 8. Final summary across folds
# ————————————————————————————————————————————————————————————————
# Mean of the per-fold best validation accuracies.
avg_acc = sum(fold_accuracies) / len(fold_accuracies)

# Build the report line by line, then emit it in a single print call.
summary_lines = ["\n########## K‑Fold Summary ##########"]
summary_lines.extend(
    f" • Fold {fold_no}: {fold_acc:.2f}%"
    for fold_no, fold_acc in enumerate(fold_accuracies, start=1)
)
summary_lines.append(f" • Average Val Accuracy: {avg_acc:.2f}%")
print("\n".join(summary_lines))
