<a href="https://colab.research.google.com/github/JohnYechanJo/Novo-Nordisk_Anomaly-Detection/blob/classifier/Final_3_times.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Synthetic CNV Image Classifier
This notebook uses CNV images from the OCT2017 dataset to generate synthetic CNV images with a Stable Diffusion model. The synthetic images are then mixed with real CNV images at various ratios (0% to 100%), alongside NORMAL images, to train a classifier. The goal is to find the ratio of synthetic images that maximizes classifier performance.

Execution Steps:

1. Data Preprocessing: OCT2017 CNV/NORMAL images → ViT embeddings → pre-trained_dataset.pt.

2. Diffusion Model Fine-tuning: Fine-tune Stable Diffusion UNet with CNV images.

3. Synthetic CNV Image Generation: Generate images using the fine-tuned model → ViT embeddings → synthetic_cnv_dataset.pt.

4. Classifier Training by Ratio: Mix data at synthetic ratios from 0% to 100% → Train classifier → Compare performance.

Execution Environment: Google Colab (GPU, e.g., T4 or A100 recommended).

## 1. Environment Setup and Package Installation
Install the required Python packages and set the random seed to ensure reproducibility. The GPU is used when available, and a helper function for memory management is also defined.

In [1]:
!pip install kagglehub torch torchvision transformers diffusers accelerate datasets xformers pytorch-fid pandas
import os
import gc
import torch
import numpy as np
import random
from PIL import Image
import pandas as pd
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
from transformers import ViTModel, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline
from accelerate import Accelerator
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
from pytorch_fid import fid_score

# GPU setup
# Restrict the process to the first GPU. This has to happen before the first
# CUDA query below: once the CUDA runtime has enumerated the devices, later
# changes to CUDA_VISIBLE_DEVICES are ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set random seed for every random number generator used in the notebook
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Memory cleanup function
def clear_memory():
    """Run the Python garbage collector, then release PyTorch's cached CUDA memory."""
    # Collect first so tensors that just became unreachable are dropped
    # before the CUDA cache is emptied.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()



In [2]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import os
import random
from PIL import Image
from torchvision import transforms


# Rebinds the module-level `device` (assigned earlier as a torch.device) to a
# plain string naming the backend.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def img_transform():
    """Build the preprocessing pipeline applied to every PIL image.

    Returns:
        A `transforms.Compose` that crops the image, resizes it to 224x224,
        converts it to a tensor and normalizes each channel.
    """
    # (left, upper, right, lower) box handed to PIL's Image.crop.
    crop_box = (0, 100, 768, 400)
    steps = [
        transforms.Lambda(lambda img: img.crop(crop_box)),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel mean/std (the commonly used ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def load_trans(path, pic_num=640):
    """Load up to `pic_num` images from `path` as preprocessed tensors.

    Args:
        path: directory that is scanned (non-recursively) for image files.
        pic_num: maximum number of images to return.

    Returns:
        A list of tensors produced by `img_transform()`, one per image that
        was loaded successfully. Files that are not .png/.jpg/.jpeg are
        ignored, and files that fail to load are reported and skipped;
        neither counts towards `pic_num`.
    """
    trans_toTensor = img_transform()
    image_list = []
    # os.listdir returns entries in arbitrary order; sorting makes the
    # selected subset identical on every run and machine.
    for filename in sorted(os.listdir(path)):
        # Count loaded images (not scanned directory entries) against the cap.
        if len(image_list) >= pic_num:
            break
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        file_path = os.path.join(path, filename)
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(file_path) as img:
                tensor_img = trans_toTensor(img.convert("RGB"))
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the load.
            print(f"Skip: {filename}, Error: {e}")
            continue
        image_list.append(tensor_img)
    return image_list

class OCTDataset(Dataset):
    """Map-style dataset over the image files of one directory, all sharing one label.

    Args:
        root_dir: directory scanned (non-recursively) for image files.
        label: label returned with every image of this dataset.
        transform: callable applied to each RGB PIL image in `__getitem__`.
    """

    # Same extensions that load_trans accepts, matched case-insensitively.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, label, transform):
        # Sorted so that an index always refers to the same file; os.listdir
        # order is arbitrary, which would otherwise make seeded index
        # sampling irreproducible.
        self.paths = [
            os.path.join(root_dir, f)
            for f in sorted(os.listdir(root_dir))
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.label = label
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return `(transformed_image, label)` for the file at position `idx`."""
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)
        return img, self.label


class BalancedBatchSampler(Sampler):
    """Pre-draw index batches with a fixed synthetic-CNV / real-CNV / normal split.

    Each batch consists of `batch_size // 2` normal indices and
    `batch_size // 2` CNV indices, of which `int(cnv * syn_cnv_ratio)` are
    synthetic and the rest real. All indices are drawn without replacement at
    construction time from a `random.Random(seed)` generator, so iterating
    the sampler repeatedly yields the same batches.

    Args:
        syn_cnv_len: number of synthetic CNV samples available.
        real_cnv_len: number of real CNV samples available.
        norm_len: number of normal samples available.
        batch_size: nominal batch size (split in half between CNV and normal).
        syn_cnv_ratio: share of the CNV half taken from the synthetic pool.
        num_batches: number of batches yielded per iteration.
        seed: seed of the private random generator.
    """

    def __init__(self, syn_cnv_len, real_cnv_len, norm_len, batch_size=128, syn_cnv_ratio=0.25, num_batches=30,seed=123):
        self.syn_cnv_len = syn_cnv_len
        self.real_cnv_len = real_cnv_len
        self.norm_len = norm_len
        self.batch_size = batch_size
        self.syn_cnv_ratio = syn_cnv_ratio
        self.num_batches = num_batches
        self.seed = seed

        half = batch_size // 2
        self.cnv_batch = half
        self.norm_batch = half
        self.syn_cnv_batch = int(half * syn_cnv_ratio)
        self.real_cnv_batch = half - self.syn_cnv_batch

        pool_sizes = (self.syn_cnv_len, self.real_cnv_len, self.norm_len)
        per_batch = (self.syn_cnv_batch, self.real_cnv_batch, self.norm_batch)
        required = tuple(count * num_batches for count in per_batch)

        # Every pool must be large enough to be sampled without replacement.
        for available, needed in zip(pool_sizes, required):
            assert available >= needed

        # Draw order (synthetic, real CNV, normal) determines the seeded result.
        picker = random.Random(seed)
        self.syn_indices, self.real_cnv_indices, self.norm_indices = (
            picker.sample(range(available), needed)
            for available, needed in zip(pool_sizes, required)
        )

    def __iter__(self):
        """Yield `(syn_idx, real_cnv_idx, norm_idx)` index lists, one tuple per batch."""
        def window(indices, size, batch_no):
            start = batch_no * size
            return indices[start:start + size]

        for batch_no in range(self.num_batches):
            yield (
                window(self.syn_indices, self.syn_cnv_batch, batch_no),
                window(self.real_cnv_indices, self.real_cnv_batch, batch_no),
                window(self.norm_indices, self.norm_batch, batch_no),
            )

    def __len__(self):
        """Return the number of batches per iteration."""
        return self.num_batches



class _MixedBatchDataset(Dataset):
    """Map-style dataset whose item `i` is the fully assembled `i`-th training batch."""

    def __init__(self, syn_dataset, cnv_dataset, norm_dataset, index_batches):
        self.syn_dataset = syn_dataset
        self.cnv_dataset = cnv_dataset
        self.norm_dataset = norm_dataset
        # One (syn_idx, real_cnv_idx, norm_idx) tuple per batch.
        self.index_batches = index_batches

    def __len__(self):
        return len(self.index_batches)

    def __getitem__(self, batch_no):
        syn_idx, cnv_idx, norm_idx = self.index_batches[batch_no]
        samples = [self.syn_dataset[i] for i in syn_idx]
        samples += [self.cnv_dataset[i] for i in cnv_idx]
        samples += [self.norm_dataset[i] for i in norm_idx]
        # Shuffle so the three sources are interleaved inside the batch.
        random.shuffle(samples)
        imgs, labels = zip(*samples)
        return torch.stack(imgs), torch.tensor(labels)


def build_dataloader(syn_dir, cnv_dir, norm_dir, batch_size=128, syn_ratio=0.5):
    """Create a loader yielding balanced `(images, labels)` training batches.

    Args:
        syn_dir: directory with synthetic CNV images (label 1).
        cnv_dir: directory with real CNV images (label 1).
        norm_dir: directory with normal images (label 0).
        batch_size: nominal batch size handed to `BalancedBatchSampler`.
        syn_ratio: share of each batch's CNV half drawn from `syn_dir`.

    Returns:
        A `DataLoader` with `len(sampler)` items; each item is a tuple of a
        stacked image tensor and a label tensor for one whole batch.
    """
    syn_cnv_dataset = OCTDataset(syn_dir, 1, img_transform())
    real_cnv_dataset = OCTDataset(cnv_dir, 1, img_transform())
    norm_dataset = OCTDataset(norm_dir, 0, img_transform())

    sampler = BalancedBatchSampler(
        syn_cnv_len=len(syn_cnv_dataset),
        real_cnv_len=len(real_cnv_dataset),
        norm_len=len(norm_dataset),
        batch_size=batch_size,
        syn_cnv_ratio=syn_ratio
    )

    # The sampler's index tuples must actually drive batch construction.
    # Previously the loader iterated a dummy list of integers, so the collate
    # function received `[i]` instead of an index tuple and failed to unpack.
    batches = _MixedBatchDataset(
        syn_cnv_dataset, real_cnv_dataset, norm_dataset, list(sampler)
    )
    # batch_size=None disables automatic batching: every dataset item already
    # is a complete batch and is passed through unchanged.
    loader = DataLoader(batches, batch_size=None, num_workers=2, pin_memory=True)
    return loader


In [2]:
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Small CNN for 2-class classification of 3x224x224 images.

    Three Conv(3x3, padding 1) -> ReLU -> MaxPool(2) stages take the input
    from (3,224,224) to (64,28,28); a flatten + two linear layers then
    produce two logits per image.
    """

    def __init__(self):
        super().__init__()
        # Channel widths of consecutive stages: 3 -> 16 -> 32 -> 64.
        widths = (3, 16, 32, 64)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),  # keeps H x W
                nn.ReLU(),
                nn.MaxPool2d(2),                                    # halves H and W
            ]
        self.conv_block = nn.Sequential(*stages)
        self.fc = nn.Sequential(
            nn.Flatten(),                    # -> (64*28*28)
            nn.Linear(64 * 28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 2),               # 2-class classification
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.fc(self.conv_block(x))
class Classifier(nn.Module):
    """Train a `SimpleCNN` on mixed batches and evaluate it on the saved hold-out sets.

    Args:
        syn_ratio: share of each batch's CNV half drawn from the synthetic
            images. Used by `train_val_test` unless overridden per call.
    """

    def __init__(self, syn_ratio=0.5):
        super().__init__()
        # Fall back to the CPU instead of failing when CUDA is unavailable.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SimpleCNN().to(self.device)
        self.best_acc = 0
        self.syn_ratio = syn_ratio

    def _predict(self, inputs, chunk_size=64):
        """Return the predicted class index of every row of `inputs` as a list."""
        self.model.eval()
        with torch.no_grad():
            # Evaluate in chunks so the whole set never has to fit in one pass.
            preds = [
                torch.max(self.model(chunk), dim=1)[1]
                for chunk in inputs.split(chunk_size)
            ]
        return torch.cat(preds).cpu().numpy().tolist()

    def train_val_test(self, syn_dir, cnv_dir, norm_dir, syn_ratio=None):
        """Train for 15 epochs, validating after each one, then report test metrics.

        Validation and test tensors are read from 'Dataset.pt' (keys
        'val_data', 'val_label', 'test_data', 'test_label').

        Args:
            syn_dir: directory with synthetic CNV images.
            cnv_dir: directory with real CNV images.
            norm_dir: directory with normal images.
            syn_ratio: optional override of `self.syn_ratio` for this run.

        Returns:
            The `classification_report` dictionary for the test set, or a
            zero-filled stand-in when the report raises `ValueError`.

        Raises:
            FileNotFoundError: if 'Dataset.pt' does not exist.
        """
        # Fail before training rather than with a NameError at validation time.
        if not os.path.exists('Dataset.pt'):
            raise FileNotFoundError(
                "Dataset.pt not found; save the validation/test tensors before training."
            )
        device = self.device
        ratio = self.syn_ratio if syn_ratio is None else syn_ratio
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3, weight_decay = 1e-4)
        # unpack
        data_dic = torch.load('Dataset.pt')
        mixed_val = data_dic['val_data'].to(device)
        mixed_val_label = data_dic['val_label'].to(device)
        mixed_test = data_dic['test_data'].to(device)
        mixed_test_label = data_dic['test_label'].to(device)
        # Forward the requested ratio; otherwise the loader's default is used
        # regardless of what the caller asked for.
        train_loader = build_dataloader(syn_dir, cnv_dir, norm_dir, syn_ratio=ratio)
        loss = nn.CrossEntropyLoss()
        epochs = 15
        total = len(train_loader)
        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            # train
            self.model.train()
            for i, data in enumerate(train_loader):
                batch_x, batch_y = (item.to(device) for item in data)
                self.optimizer.zero_grad()
                logit_original = self.model(batch_x)
                l = loss(logit_original, batch_y)
                l.backward()
                self.optimizer.step()
                corrects = (torch.max(logit_original, 1)[1] == batch_y).sum().item()
                accuracy = 100 * corrects / len(batch_y)
                print(f'Batch[{i + 1}/{total}] - loss: {l.item():.6f}  accuracy: {accuracy:.4f}%({corrects}/{batch_y.size(0)})')
            # val
            y_pred = self._predict(mixed_val)
            acc = accuracy_score(mixed_val_label.cpu().numpy().tolist(), y_pred)
            print(f"Validation Accuracy: {acc:.4f}")
            if acc > self.best_acc:
                self.best_acc = acc
            print("Best val set acc:", self.best_acc)
        # test
        # NOTE: the metrics below use the weights of the last epoch; best_acc
        # is tracked for reporting only.
        y_pred = self._predict(mixed_test)
        try:
            res = classification_report(mixed_test_label.cpu().numpy().tolist(), y_pred, labels=[0, 1], target_names=['NR', 'FR'], digits=3, output_dict=True)
            for k, v in res.items():
                print(k, v)
            print(f"result: {res['accuracy']:.4f}")
        except ValueError as e:
            print(f"Error in classification_report: {e}")
            res = {'accuracy': 0, 'macro avg': {'f1-score': 0, 'precision': 0, 'recall': 0}}
        return res




def runable(model,syn_dir, cnv_dir, norm_dir):
    """Run the model's full train/validate/test cycle and return its result.

    Args:
        model: object exposing `train_val_test(syn_dir, cnv_dir, norm_dir)`.
        syn_dir: directory with synthetic CNV images.
        cnv_dir: directory with real CNV images.
        norm_dir: directory with normal images.

    Returns:
        Whatever `model.train_val_test` returns.
    """
    return model.train_val_test(syn_dir, cnv_dir, norm_dir)


In [1]:
import torch
import numpy as np
import pandas as pd
import os
import kagglehub
from google.colab import drive
# Mount Google Drive at /content/drive so files stored there become readable.
drive.mount('/content/drive')
# Directory of the synthetic CNV images; handed to runable() as the synthetic
# image directory by train_classifier_with_ratios().
synthetic_path = "/content/drive/MyDrive/synthetic_cnv_merged"

def train_classifier_with_ratios():
    """Train one classifier per synthetic ratio (0.0 to 1.0) and collect test metrics.

    Downloads the OCT dataset, builds fixed validation and test sets
    (120 CNV + 120 NORMAL images each), saves them to 'Dataset.pt', then
    trains and evaluates a fresh `Classifier` for every ratio. The metrics
    are printed and written to 'classifier_results.csv'.

    Raises:
        ValueError: if fewer hold-out images than required could be loaded.
    """
    results = []
    # download dataset
    path = kagglehub.dataset_download("paultimothymooney/kermany2018")
    loadpath = os.path.join(path, "OCT2017 /train")
    loadpath_1 = os.path.join(path, "OCT2017 /test")
    train_path_cnv = os.path.join(loadpath, "CNV")
    train_path_normal = os.path.join(loadpath, "NORMAL")
    # 120:120 *2
    # Hold-out images come from the test split so that they cannot also be
    # sampled from the training folders (which would leak into val/test).
    tv_path_cnv = os.path.join(loadpath_1, "CNV")
    tv_path_normal = os.path.join(loadpath_1, "NORMAL")
    per_split = 120
    needed = 2 * per_split
    cnv_tensor_list = load_trans(tv_path_cnv, pic_num=needed)
    normal_tensor_list = load_trans(tv_path_normal, pic_num=needed)
    # The label tensors below assume exactly `per_split` images per class and
    # split; a short load would silently misalign data and labels.
    if len(cnv_tensor_list) < needed or len(normal_tensor_list) < needed:
        raise ValueError(
            f"Need {needed} hold-out images per class, got "
            f"{len(cnv_tensor_list)} CNV and {len(normal_tensor_list)} NORMAL."
        )
    cnv_tensor_list = torch.stack(cnv_tensor_list)
    normal_tensor_list = torch.stack(normal_tensor_list)

    # The hold-out sets do not depend on the ratio, so build and save them once.
    def split_labels():
        # save norm->0 / cnv->1
        return torch.cat([torch.ones(per_split, dtype=torch.long),
                          torch.zeros(per_split, dtype=torch.long)], dim=0)

    mixed_val = torch.cat([cnv_tensor_list[:per_split],
                           normal_tensor_list[:per_split]], dim=0)
    mixed_test = torch.cat([cnv_tensor_list[per_split:needed],
                            normal_tensor_list[per_split:needed]], dim=0)
    # torch.save overwrites an existing file, so no explicit removal is needed.
    torch.save({
        "val_data": mixed_val,
        "val_label": split_labels(),
        "test_data": mixed_test,
        "test_label": split_labels()
    },"Dataset.pt")
    clear_memory()

    # ratio : 0~1
    ratios = [i/10 for i in range(11)]
    for ratio in ratios:
        # Train and evaluate model
        print(f"Training model with ratio {ratio}")
        model = Classifier()
        # Expose the requested synthetic share on the model so the training
        # routine can build its loader with it.
        # NOTE(review): confirm Classifier.train_val_test forwards
        # `self.syn_ratio` to build_dataloader; if it does not, every run
        # trains with the loader's default ratio and the sweep is meaningless.
        model.syn_ratio = ratio
        res = runable(model,synthetic_path,train_path_cnv,train_path_normal)
        results.append({
            'ratio': ratio,
            'accuracy': res['accuracy'],
            'f1_score': res['macro avg']['f1-score'],
            'precision': res['macro avg']['precision'],
            'recall': res['macro avg']['recall']
        })
        # Release the finished model's memory before the next run.
        clear_memory()


    print("\nResults for Ratio 0 to 1:")
    for res in results:
        if res['accuracy'] is not None:
            print(f"Ratio: {res['ratio']*100:.0f}% | Accuracy: {res['accuracy']:.4f} | F1 Score: {res['f1_score']:.4f} | Precision: {res['precision']:.4f} | Recall: {res['recall']:.4f}")


    pd.DataFrame(results).to_csv('classifier_results.csv', index=False)
    print("Results saved to classifier_results.csv")

# Run the full ratio sweep when this cell executes.
train_classifier_with_ratios()

Mounted at /content/drive


NameError: name 'load_trans' is not defined