In [3]:
import torch
import open_clip
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from glob import glob
from torchvision import transforms
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from lightgbm import LGBMClassifier
import warnings
import albumentations as A
import os

warnings.filterwarnings('ignore')

albumentations_aug = A.Compose([
    A.OneOf([
        A.MotionBlur(p=0.3),
        A.GaussianBlur(blur_limit=5, p=0.5),
        A.MedianBlur(blur_limit=5, p=0.5),
    ], p=0.6),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.Downscale(scale_min=0.7, scale_max=0.95, p=0.3),
    A.RandomResizedCrop(size=(672, 672), scale=(0.8, 1.0), ratio=(0.75, 1.33), p=0.4),
    A.GaussNoise(var_limit=(5.0, 20.0), p=0.3),
    A.ImageCompression(quality_lower=30, quality_upper=80, compression_type='jpeg', p=0.2),
])

# === 1. Load CLIP Model ===
def load_clip(device='cuda' if torch.cuda.is_available() else 'cpu'):
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    model.to(device).eval()
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    return model, preprocess, tokenizer, device

# === 2. Pupil Cropping ===
def crop_to_pupil(image_path, output_size=(224, 224)):
    image = cv2.imread(image_path)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray_blur = cv2.medianBlur(gray, 5)

    circles = cv2.HoughCircles(gray_blur, cv2.HOUGH_GRADIENT, dp=1.5, minDist=30,
                                param1=50, param2=30, minRadius=20, maxRadius=150)

    if circles is not None:
        circles = np.uint16(np.around(circles))
        x, y, r = circles[0][0]
        pad = int(r * 1.5)
        x1, y1 = max(0, x - pad), max(0, y - pad)
        x2, y2 = min(image.shape[1], x + pad), min(image.shape[0], y + pad)
        cropped = image[y1:y2, x1:x2]
    else:
        print(f"⚠️ Pupil not detected in {image_path}, using full image.")
        cropped = image

    resized = cv2.resize(cropped, output_size)
    return resized

# === 3. Save Cropped Images ===
def preprocess_folder(input_folder, output_folder, size=(224, 224)):
    os.makedirs(output_folder, exist_ok=True)
    for fname in tqdm(os.listdir(input_folder)):
        if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
            path = os.path.join(input_folder, fname)
            cropped = crop_to_pupil(path, output_size=size)
            save_path = os.path.join(output_folder, fname)
            cv2.imwrite(save_path, cropped)

# === 4. Augmentation (for training only) ===
augmentation_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.2)
])

def augment_image(img_pil, n_augmentations=2):
    augmented = [augmentation_transforms(img_pil) for _ in range(n_augmentations)]
    grayscale = transforms.Grayscale()(img_pil)
    augmented.append(grayscale.convert("RGB"))

    np_img = np.array(img_pil)
    gray = cv2.cvtColor(np_img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    edge_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    edge_pil = Image.fromarray(edge_rgb)
    augmented.append(edge_pil)
    return augmented

# === 5. Build DataFrame for Training ===
def build_dataframe(folder, label, preprocess, model, device='cuda', augment=True, n_aug=3):
    df = pd.DataFrame(columns=range(512))
    idx = 0
    for image_path in tqdm(glob(f"{folder}/*.png")):
        image = Image.open(image_path).convert("RGB")
        images = [image]
        if augment:
            np_img = np.array(image)
            for _ in range(n_aug):
                aug_img = albumentations_aug(image=np_img)['image']
                aug_img_pil = Image.fromarray(aug_img)
                images.append(aug_img_pil)
            images += augment_image(image, n_augmentations=0)

        for img in images:
            with torch.no_grad(), torch.cuda.amp.autocast():
                tensor = preprocess(img).unsqueeze(0).to(device)
                feat = model.encode_image(tensor)
                feat = feat / feat.norm(dim=-1, keepdim=True)
                df.loc[idx, list(range(512))] = feat.cpu().numpy()[0]
                df.loc[idx, 'category'] = label
                idx += 1
    return df

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import mlflow
import time
import mlflow.sklearn
from sklearn.metrics import accuracy_score, f1_score
with mlflow.start_run(run_name="CLIP_Pytotch_TL_Cataract_Classifier"):

    # Log environment info
    mlflow.set_tag("clip_model", "ViT-B-32")
    mlflow.set_tag("pytorch","ViT-B-32")
    mlflow.log_param("device", 'cpu')
    mlflow.log_param("n_augmentations", 3)
    model, preprocess, tokenizer, device = load_clip()

    # Step A: Preprocess (optional)
    preprocess_folder("processed_images/train/normal", "processed_aug_aug/train/normal")
    preprocess_folder("processed_images/train/cataract", "processed_aug_aug/train/cataract")

    # Step B: Build dataset
    start_train_time = time.time()
    df_normal = build_dataframe("processed_aug_aug/train/normal", 0, preprocess, model, device=device)
    df_cataract = build_dataframe("processed_aug_aug/train/cataract", 1, preprocess, model, device=device)
    df = pd.concat([df_normal, df_cataract]).astype(float).sample(frac=1).reset_index(drop=True)


  3%|██▉                                                                                                     | 7/246 [00:01<01:00,  3.95it/s]

KeyboardInterrupt



In [3]:
df_normal = build_dataframe("processed_aug_aug/train/normal", 0, preprocess, model, device=device)
df_cataract = build_dataframe("processed_aug_aug/train/cataract", 1, preprocess, model, device=device)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 246/246 [01:59<00:00,  2.06it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 245/245 [01:08<00:00,  3.60it/s]


In [4]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import cv2
import numpy as np
from glob import glob

class CataractDataset(Dataset):
    def __init__(self, root_dir, transform=None, augment_fn=None, target_size=(224, 224)):
        """
        root_dir: path to 'processed_aug_aug/train'
        transform: torchvision transforms to convert to tensor and normalize
        augment_fn: albumentations pipeline (optional)
        """
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {'normal': 0, 'cataract': 1}

        for label_name in ['normal', 'cataract']:
            folder = os.path.join(root_dir, label_name)
            paths = glob(f"{folder}/*.png") + glob(f"{folder}/*.jpg") + glob(f"{folder}/*.jpeg")
            self.image_paths.extend(paths)
            self.labels.extend([self.class_to_idx[label_name]] * len(paths))

        self.transform = transform
        self.augment_fn = augment_fn
        self.target_size = target_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Crop to pupil
        image = crop_to_pupil(image_path, output_size=self.target_size)  # np.array in BGR

        # Convert BGR to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply albumentations (optional)
        if self.augment_fn:
            image = self.augment_fn(image=image)['image']

        # Convert to PIL for torchvision transforms
        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        return image, label

In [8]:
from torch.utils.data import DataLoader
from torchvision import transforms

# CLIP-normalization or standard ImageNet if using ResNet/Vit
clip_mean = [0.48145466, 0.4578275, 0.40821073]
clip_std  = [0.26862954, 0.26130258, 0.27577711]

# Final torchvision transform
final_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=clip_mean, std=clip_std)
])

# Train dataset and loader
train_dataset = CataractDataset(
    root_dir='processed_aug_aug/train',
    transform=final_transform,
    augment_fn=albumentations_aug,
    target_size=(224, 224)  # or 512x512 depending on model
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)


In [9]:
for images, labels in train_loader:
    images = images.to(device)
    labels = labels.to(device)


KeyboardInterrupt



In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import open_clip
from tqdm import tqdm

# === 1. Model Setup ===

class CLIPClassifier(nn.Module):
    def __init__(self, clip_model, num_classes=2):
        super().__init__()
        self.clip = clip_model.visual  # Only image encoder
        embed_dim = clip_model.text_projection.shape[1]
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.clip(x)  # returns image embedding (B, 512)
        return self.classifier(x)

def freeze_clip_layers(clip_visual, n_frozen_layers=10):
    """
    Freezes first n_frozen_layers of the CLIP ViT encoder.
    For ViT-B/32, total blocks = 12
    """
    for name, param in clip_visual.named_parameters():
        if 'transformer.resblocks' in name:
            block_idx = int(name.split('.')[2])
            if block_idx < n_frozen_layers:
                param.requires_grad = False
            else:
                param.requires_grad = True
        elif 'ln_post' in name or 'proj' in name:
            param.requires_grad = True
        else:
            param.requires_grad = False  # patch embedding, pos_embed, etc.

# === 2. Training Utilities ===

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in tqdm(dataloader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validation", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

# === 3. Full Training Function ===

def run_training(train_loader, val_loader, n_frozen_layers=10, num_epochs=5, lr=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load CLIP model
    clip_model, preprocess, tokenizer = open_clip.create_model_and_transforms(
        model_name='ViT-B-32',
        pretrained='laion2b_s34b_b79k'
    )
    model = CLIPClassifier(clip_model).to(device)

    # Freeze layers
    freeze_clip_layers(model.clip, n_frozen_layers=n_frozen_layers)

    # Define optimizer and loss
    criterion = nn.CrossEntropyLoss()

    head_params = model.classifier.parameters()
    clip_params = [p for p in model.clip.parameters() if p.requires_grad]

    optimizer = optim.AdamW([
        {'params': clip_params, 'lr': lr / 10},
        {'params': head_params, 'lr': lr}
    ])

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs} | Frozen CLIP layers: {n_frozen_layers}")
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.2f}%")

    return model


In [11]:
run_training(train_loader,train_loader)


Epoch 1/5 | Frozen CLIP layers: 10


Training:  38%|███████████████████████████████████▋                                                           | 6/16 [00:45<01:10,  7.07s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


Validation:  19%|█████████████████▍                                                                           | 3/16 [00:21<01:32,  7.10s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


                                                                                                                                             

Train Loss: 0.6063 | Train Acc: 65.38%
Val   Loss: 0.5089 | Val   Acc: 77.39%

Epoch 2/5 | Frozen CLIP layers: 10


Training:  81%|████████████████████████████████████████████████████████████████████████████▍                 | 13/16 [01:21<00:16,  5.49s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


Validation:  50%|██████████████████████████████████████████████▌                                              | 8/16 [00:47<00:48,  6.06s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


                                                                                                                                             

Train Loss: 0.4566 | Train Acc: 77.80%
Val   Loss: 0.4007 | Val   Acc: 82.08%

Epoch 3/5 | Frozen CLIP layers: 10


Training:  69%|████████████████████████████████████████████████████████████████▋                             | 11/16 [01:04<00:29,  5.83s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


Validation:  94%|██████████████████████████████████████████████████████████████████████████████████████▎     | 15/16 [01:24<00:05,  5.17s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


                                                                                                                                             

Train Loss: 0.3892 | Train Acc: 80.24%
Val   Loss: 0.3655 | Val   Acc: 82.69%

Epoch 4/5 | Frozen CLIP layers: 10


Training:  88%|██████████████████████████████████████████████████████████████████████████████████▎           | 14/16 [01:27<00:13,  6.61s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


Validation:  38%|██████████████████████████████████▉                                                          | 6/16 [00:57<01:37,  9.75s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


                                                                                                                                             

Train Loss: 0.3402 | Train Acc: 85.34%
Val   Loss: 0.3220 | Val   Acc: 87.58%

Epoch 5/5 | Frozen CLIP layers: 10


Training:  38%|███████████████████████████████████▋                                                           | 6/16 [00:58<01:39,  9.92s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


Validation:  75%|█████████████████████████████████████████████████████████████████████                       | 12/16 [01:44<00:29,  7.28s/it]

⚠️ Pupil not detected in processed_aug_aug/train/normal/image_83.png, using full image.


                                                                                                                                             

Train Loss: 0.2966 | Train Acc: 87.17%
Val   Loss: 0.2862 | Val   Acc: 88.80%




CLIPClassifier(
  (clip): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwis