In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import clip

In [4]:
DATA_ROOT = "datasets/clean_data"
MODEL_NAME = "ViT-B/32"
BATCH_SIZE = 64
EPOCHS = 20
PATIENCE = 5
LR_HEAD = 1e-3
NUM_WORKERS = 4
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [5]:
model, preprocess = clip.load(MODEL_NAME, device=device)

train_ds = ImageFolder(os.path.join(DATA_ROOT, "train"), transform=preprocess)
val_ds   = ImageFolder(os.path.join(DATA_ROOT, "val"),   transform=preprocess)
test_ds  = ImageFolder(os.path.join(DATA_ROOT, "test"),  transform=preprocess)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print("Classes:", train_ds.classes)

Classes: ['medical', 'non_medical']


In [6]:
from torch.cuda.amp import GradScaler, autocast

use_amp = torch.cuda.is_available()
scaler = GradScaler(enabled=use_amp)

for p in model.parameters():
    p.requires_grad = False

feat_dim = model.visual.output_dim
classifier = nn.Linear(feat_dim, 2).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=LR_HEAD)


use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

@torch.no_grad()
def encode_images(x):
    
    with torch.cuda.amp.autocast(enabled=use_amp):
        feats = model.encode_image(x).float()
    return feats


  scaler = GradScaler(enabled=use_amp)
  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


In [7]:
def evaluate(loader):
    classifier.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    all_y, all_p = [], []
    with torch.no_grad():
        for imgs, ys in loader:
            imgs, ys = imgs.to(device), ys.to(device)
            feats = encode_images(imgs)
            logits = classifier(feats)
            loss = criterion(logits, ys)
            total_loss += loss.item() * ys.size(0)
            preds = logits.argmax(1)
            total_correct += (preds == ys).sum().item()
            total += ys.size(0)
            all_y.append(ys.cpu().numpy())
            all_p.append(preds.cpu().numpy())
    avg_loss = total_loss / max(1, total)
    acc = total_correct / max(1, total)
    all_y = np.concatenate(all_y) if all_y else np.array([])
    all_p = np.concatenate(all_p) if all_p else np.array([])
    return avg_loss, acc, all_y, all_p

best_val = float("inf")
best_state = None
wait = 0

print("Starting training...")
for epoch in range(1, EPOCHS+1):
    classifier.train()
    running = 0.0
    for imgs, ys in train_loader:
        imgs, ys = imgs.to(device), ys.to(device)
        feats = encode_images(imgs)
        optimizer.zero_grad(set_to_none=True)
        if use_amp:
            with torch.cuda.amp.autocast():
                logits = classifier(feats)
                loss = criterion(logits, ys)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = classifier(feats)
            loss = criterion(logits, ys)
            loss.backward()
            optimizer.step()
        running += loss.item() * ys.size(0)

    train_loss = running / max(1, len(train_ds))
    val_loss, val_acc, _, _ = evaluate(val_loader)
    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        best_state = classifier.state_dict()
        wait = 0
    else:
        wait += 1
        if wait >= PATIENCE:
            print("Early stopping.")
            break

# Load best classifier
classifier.load_state_dict(best_state)


Starting training...


  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast():


Epoch 01 | train_loss=0.1803 | val_loss=0.0348 | val_acc=1.0000
Epoch 02 | train_loss=0.0236 | val_loss=0.0154 | val_acc=1.0000
Epoch 03 | train_loss=0.0125 | val_loss=0.0092 | val_acc=1.0000
Epoch 04 | train_loss=0.0081 | val_loss=0.0064 | val_acc=1.0000
Epoch 05 | train_loss=0.0058 | val_loss=0.0047 | val_acc=1.0000
Epoch 06 | train_loss=0.0044 | val_loss=0.0037 | val_acc=1.0000
Epoch 07 | train_loss=0.0035 | val_loss=0.0030 | val_acc=1.0000
Epoch 08 | train_loss=0.0029 | val_loss=0.0025 | val_acc=1.0000
Epoch 09 | train_loss=0.0024 | val_loss=0.0022 | val_acc=1.0000
Epoch 10 | train_loss=0.0020 | val_loss=0.0019 | val_acc=1.0000
Epoch 11 | train_loss=0.0018 | val_loss=0.0017 | val_acc=1.0000
Epoch 12 | train_loss=0.0015 | val_loss=0.0015 | val_acc=1.0000
Epoch 13 | train_loss=0.0013 | val_loss=0.0013 | val_acc=1.0000
Epoch 14 | train_loss=0.0012 | val_loss=0.0012 | val_acc=1.0000
Epoch 15 | train_loss=0.0011 | val_loss=0.0011 | val_acc=1.0000
Epoch 16 | train_loss=0.0010 | val_loss=

<All keys matched successfully>

In [8]:

test_loss, test_acc, y_true, y_pred = evaluate(test_loader)
print(f"Test: loss={test_loss:.4f}, acc={test_acc:.4f}")
print("Confusion matrix:\n", confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, target_names=train_ds.classes))
print(train_ds.classes)

  with torch.cuda.amp.autocast(enabled=use_amp):


Test: loss=0.0006, acc=1.0000
Confusion matrix:
 [[372   0]
 [  0 372]]
              precision    recall  f1-score   support

     medical       1.00      1.00      1.00       372
 non_medical       1.00      1.00      1.00       372

    accuracy                           1.00       744
   macro avg       1.00      1.00      1.00       744
weighted avg       1.00      1.00      1.00       744

['medical', 'non_medical']


In [9]:
val_loss, val_acc, y_true, y_pred = evaluate(val_loader)
print(f"Test: loss={val_loss:.4f}, acc={val_acc:.4f}")
print("Confusion matrix:\n", confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, target_names=train_ds.classes))
print(train_ds.classes)

  with torch.cuda.amp.autocast(enabled=use_amp):


Test: loss=0.0008, acc=1.0000
Confusion matrix:
 [[370   0]
 [  0 370]]
              precision    recall  f1-score   support

     medical       1.00      1.00      1.00       370
 non_medical       1.00      1.00      1.00       370

    accuracy                           1.00       740
   macro avg       1.00      1.00      1.00       740
weighted avg       1.00      1.00      1.00       740

['medical', 'non_medical']


In [11]:
save_path = "model/clip_finetuned.pt"
meta = {
    "model_name": MODEL_NAME,                 # e.g., "ViT-B/32"
    "classes": train_ds.classes,              # ["medical","non_medical"]
    "preprocess_info": "use clip.load preprocess",  # reminder
    "seed": SEED,
}
torch.save({
    "classifier_state": classifier.state_dict(),
    "meta": meta,
}, save_path)
print("Saved:", save_path)

Saved: model/clip_finetuned.pt


In [None]:
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"


checkpoint = torch.load("model/clip_finetuned.pt", map_location=device)
model, preprocess = clip.load(checkpoint["meta"]["model_name"], device=device)
for p in model.parameters(): p.requires_grad = False

feat_dim = model.visual.output_dim
classifier = torch.nn.Linear(feat_dim, len(checkpoint["meta"]["classes"])).to(device)
classifier.load_state_dict(checkpoint["classifier_state"])
classifier.eval()


def predict_image(image_path):
    img = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        feats = model.encode_image(img).float()
        logits = classifier(feats)
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    cls = checkpoint["meta"]["classes"]
    return dict(zip(cls, probs))
print(predict_image("image copy.png"))

{'medical': np.float32(0.9980819), 'non_medical': np.float32(0.0019181061)}
