<a href="https://colab.research.google.com/github/JohnYechanJo/Novo-Nordisk_Anomaly-Detection/blob/main/synthetic_cnv_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Synthetic CNV Image Classifier

This notebook uses CNV images from the OCT2017 dataset to generate synthetic CNV samples using a Stable Diffusion model. These synthetic images are then mixed with real CNV images at varying ratios (from 0% to 100%) to train a classifier.
The goal is to identify the optimal ratio of synthetic data that maximizes classifier performance.

**Execution Steps**:
1. Data Preprocessing: Convert OCT2017 CNV/NORMAL images into ViT embeddings → Save as `pre-trained_dataset.pt`.
2. Diffusion Model Fine-Tuning: Fine-tune the Stable Diffusion UNet on real CNV images.
3. Synthetic CNV Image Generation: Use the fine-tuned model to generate synthetic CNV images → Convert to ViT embeddings → Save as `synthetic_cnv_dataset.pt`.
4. Classifier Training by Ratio: Mix real and synthetic CNV data at different ratios (0% to 100%) → Train classifier → Compare performance.


**Environment**: Google Colab (GPU recommended, e.g., T4 or A100)

## 1. Setup & Installation


In [1]:
!pip install kagglehub torch torchvision transformers diffusers accelerate datasets xformers pytorch-fid pandas
import os
import gc
import torch
import numpy as np
import random
from PIL import Image
import pandas as pd
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
from transformers import ViTModel, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline
from accelerate import Accelerator
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
from pytorch_fid import fid_score

# Restrict the process to GPU 0. This is done before the first CUDA query
# (torch.cuda.is_available() below) because CUDA_VISIBLE_DEVICES is only
# honoured if it is set before the CUDA runtime enumerates devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed every random number generator used in this notebook.
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Compute device: the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Memory cleanup helper
def clear_memory():
    """Run the Python garbage collector, then release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting xformers
  Downloading xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-no

## 2. Data Preprocessing: OCT2017 Dataset → ViT Embeddings


In [2]:
def img_transform():
    """Build the preprocessing pipeline applied to real OCT images.

    The pipeline crops the PIL image to the box (left=0, upper=100,
    right=768, lower=400), resizes it to 224x224, converts it to a tensor
    and normalises each channel with the mean/std values below.

    Returns:
        A ``transforms.Compose`` callable that maps a PIL image to a tensor.
    """
    def crop_scan_region(img):
        # Keep only the horizontal band between rows 100 and 400.
        return img.crop((0, 100, 768, 400))

    # NOTE(review): these mean/std values are the usual ImageNet statistics;
    # the google/vit-base-patch16-224 image processor may expect 0.5/0.5 —
    # confirm before changing, since it alters every stored embedding.
    steps = [
        transforms.Lambda(crop_scan_region),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def vit_process(img_list, batch_size=16):
    """Embed preprocessed image tensors with a pretrained ViT.

    For every image, four CLS-token vectors are produced:
      * ``hidden_states[12]`` (the output of the last encoder layer), and
      * three sums of ``hidden_states[j] / 4`` over j in 0-3, 4-7 and 8-11,
        i.e. the mean CLS vector of each group of four hidden states.

    Args:
        img_list: List of image tensors of identical shape (as produced by
            ``load_trans``).
        batch_size: Number of images sent through the model at once.

    Returns:
        Tuple of four tensors (last layer, group 0-3, group 4-7, group 8-11),
        each with one row per input image.

    Raises:
        ValueError: If ``img_list`` is empty.
    """
    if not img_list:
        raise ValueError("vit_process received an empty image list")

    model_name = "google/vit-base-patch16-224"
    vit_model = ViTModel.from_pretrained(model_name, output_hidden_states=True).to(device)
    vit_model.eval()
    layers_0, layers_1, layers_2, layers_3 = [], [], [], []
    # Step through the list in strides of batch_size so that a trailing
    # partial batch is embedded too (the previous floor-division loop
    # silently dropped the last len(img_list) % batch_size images).
    for start in range(0, len(img_list), batch_size):
        batch_imgs = img_list[start:start + batch_size]
        batch_tensor = torch.stack(batch_imgs, dim=0).to(device)
        with torch.no_grad():
            outputs = vit_model(pixel_values=batch_tensor)
        hidden = outputs.hidden_states
        # Position 0 of the sequence axis is the CLS token.
        last_layer = hidden[12][:, 0, :]
        hidden_layer_1 = torch.zeros_like(last_layer)
        hidden_layer_2 = torch.zeros_like(last_layer)
        hidden_layer_3 = torch.zeros_like(last_layer)
        for j in range(12):
            if j < 4:
                hidden_layer_1 += hidden[j][:, 0, :] / 4
            elif j < 8:
                hidden_layer_2 += hidden[j][:, 0, :] / 4
            else:
                hidden_layer_3 += hidden[j][:, 0, :] / 4
        layers_0.append(last_layer)
        layers_1.append(hidden_layer_1)
        layers_2.append(hidden_layer_2)
        layers_3.append(hidden_layer_3)
        clear_memory()
    return (torch.cat(layers_0, dim=0), torch.cat(layers_1, dim=0),
            torch.cat(layers_2, dim=0), torch.cat(layers_3, dim=0))

def load_trans(path, pic_num=640):
    """Load up to ``pic_num`` images from ``path`` as preprocessed tensors.

    Only files ending in .png/.jpg/.jpeg (case-insensitive) are considered.
    Files that fail to open or transform are reported and skipped.

    Args:
        path: Directory containing the images.
        pic_num: Maximum number of images to load.

    Returns:
        List of tensors produced by ``img_transform()``, at most ``pic_num``
        long.
    """
    trans_toTensor = img_transform()
    image_list = []
    for filename in os.listdir(path):
        # Count successfully loaded images rather than directory entries, so
        # non-image files and unreadable images no longer reduce the number
        # of samples returned.
        if len(image_list) == pic_num:
            break
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        file_path = os.path.join(path, filename)
        try:
            # The context manager closes the underlying file handle promptly.
            with Image.open(file_path) as img:
                image_list.append(trans_toTensor(img.convert("RGB")))
        except Exception as e:
            # Best effort: one corrupt file must not abort the whole load.
            print(f"Skip: {filename}, Error: {e}")
    return image_list

def preprocess_data():
    """Download OCT2017, embed CNV/NORMAL training images and save them.

    The CNV and NORMAL embeddings are interleaved in 10 chunks of 128 rows
    (64 CNV followed by 64 NORMAL, then shuffled within the chunk). CNV rows
    get label 0 and NORMAL rows get label 1. The result is written to
    ``pre-trained_dataset.pt`` with keys ``data_0`` .. ``data_3`` and
    ``label``.
    """
    print("Data Preprocessing: Start!")
    import kagglehub
    dataset_root = kagglehub.dataset_download("paultimothymooney/kermany2018")
    train_root = os.path.join(dataset_root, "OCT2017 /train")
    cnv_dir = os.path.join(train_root, "CNV")
    normal_dir = os.path.join(train_root, "NORMAL")

    cnv_images = load_trans(cnv_dir)
    normal_images = load_trans(normal_dir)
    # Each call returns four tensors: (last layer, group 1, group 2, group 3).
    cnv_layers = vit_process(cnv_images)
    normal_layers = vit_process(normal_images)

    batch_size, batch_num, half_batch = 128, 10, 64
    shuffled_layers = [[] for _ in range(4)]
    shuffled_labels = []
    for chunk in range(batch_num):
        rows = slice(chunk * half_batch, (chunk + 1) * half_batch)
        chunk_blocks = [
            torch.cat((cnv_emb[rows], normal_emb[rows]), dim=0)
            for cnv_emb, normal_emb in zip(cnv_layers, normal_layers)
        ]
        chunk_labels = torch.cat(
            [torch.zeros(half_batch, dtype=torch.long), torch.ones(half_batch, dtype=torch.long)],
            dim=0,
        )
        # One permutation per chunk, shared by all four layers and the labels
        # so that rows stay aligned.
        order = torch.randperm(batch_size)
        for bucket, block in zip(shuffled_layers, chunk_blocks):
            bucket.append(block[order])
        shuffled_labels.append(chunk_labels[order])

    merged = [torch.cat(parts, dim=0) for parts in shuffled_layers]
    merged_labels = torch.cat(shuffled_labels, dim=0)

    out_file = 'pre-trained_dataset.pt'
    if os.path.exists(out_file):
        os.remove(out_file)
    torch.save(
        {
            'data_0': merged[0],
            'data_1': merged[1],
            'data_2': merged[2],
            'data_3': merged[3],
            'label': merged_labels,
        },
        out_file,
    )
    print('Data Preprocessing: Done!')

# Build and save the ViT embedding dataset, then free cached memory.
preprocess_data()
clear_memory()

Data Preprocessing: Start!
Downloading from https://www.kaggle.com/api/v1/datasets/download/paultimothymooney/kermany2018?dataset_version_number=2...


100%|██████████| 10.8G/10.8G [04:30<00:00, 43.0MB/s]

Extracting files...



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Data Preprocessing: Done!


## 3. Diffusion Model Fine-Tuning

In [12]:
import kagglehub
def prepare_cnv_images():
    """Resize the real CNV training images to 512x512 for diffusion training.

    Downloads (or reuses) the Kaggle OCT dataset, converts every
    .png/.jpg/.jpeg file in the CNV training folder to RGB, resizes it to
    512x512 with LANCZOS resampling and writes it under
    ``/content/processed/CNV/`` with the same file name.
    """
    # Download the Kaggle dataset.
    path = kagglehub.dataset_download("paultimothymooney/kermany2018")
    print(f"Dataset downloaded to: {path}")

    # Dataset paths. The "OCT2017 /train" component (including its space) is
    # kept exactly as the rest of this notebook uses it.
    in_dir = os.path.join(path, "OCT2017 /train/CNV")
    out_dir = "/content/processed/CNV/"
    os.makedirs(out_dir, exist_ok=True)
    for fn in os.listdir(in_dir):
        # Skip non-image entries, using the same extension filter as
        # load_trans and CNVDataset, instead of failing inside Image.open.
        if not fn.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        # Close the source file as soon as the resized copy exists.
        with Image.open(os.path.join(in_dir, fn)) as src:
            img = src.convert("RGB").resize((512, 512), resample=Image.LANCZOS)
        img.save(os.path.join(out_dir, fn))
    print(f"Processed CNV images saved to {out_dir}")

class CNVDataset(Dataset):
    """Dataset of CNV images paired with one fixed, tokenised text prompt.

    Args:
        root_dir: Directory scanned (non-recursively) for .png/.jpg/.jpeg files.
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        resolution: Side length the images are resized to.
        max_length: Token sequence length used for padding/truncation.
    """

    def __init__(self, root_dir, tokenizer, resolution=512, max_length=77):
        image_suffixes = (".png", ".jpg", ".jpeg")
        self.files = []
        for entry in os.listdir(root_dir):
            if entry.lower().endswith(image_suffixes):
                self.files.append(os.path.join(root_dir, entry))
        self.tokenizer = tokenizer
        resize = transforms.Resize((resolution, resolution), transforms.InterpolationMode.LANCZOS)
        to_tensor = transforms.ToTensor()
        # Maps pixel values from [0, 1] to [-1, 1].
        rescale = transforms.Normalize([0.5], [0.5])
        self.transform = transforms.Compose([resize, to_tensor, rescale])
        self.prompt = "OCT scan showing CNV"
        self.max_length = max_length

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": image tensor, "input_ids": prompt token ids}``."""
        picture = Image.open(self.files[idx]).convert("RGB")
        pixel_values = self.transform(picture)
        encoded = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop the leading batch axis added by return_tensors="pt".
        input_ids = encoded.input_ids.squeeze(0)
        return {"pixel_values": pixel_values, "input_ids": input_ids}

def train_diffusion_model(pretrained_model="runwayml/stable-diffusion-v1-5", data_dir="/content/processed/CNV", output_dir="/content/sd_cnv_finetuned", resolution=512, batch_size=2, learning_rate=1e-4, epochs=5, grad_accum_steps=1, save_steps=1000, resume_checkpoint=None):
    """Fine-tune the Stable Diffusion UNet on the processed CNV images.

    Only the UNet is trained; the text encoder and VAE are frozen. The
    training objective is the MSE between the noise added to the VAE latents
    and the noise predicted by the UNet.

    Args:
        pretrained_model: Hub id or path of the base Stable Diffusion model.
        data_dir: Directory of training images (see ``CNVDataset``).
        output_dir: Where checkpoints, the final UNet (``final_unet``) and
            the tokenizer are written.
        resolution: Training image side length.
        batch_size: Per-step batch size.
        learning_rate: AdamW learning rate.
        epochs: Number of passes over the dataset.
        grad_accum_steps: Number of batches accumulated per optimizer step.
        save_steps: A checkpoint is written every ``save_steps`` batches.
        resume_checkpoint: Optional ``.../checkpoint_<step>`` directory
            written by a previous run; the UNet weights and accelerator
            state are restored from it and the step counter resumes at
            ``<step>``.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Pass the accumulation setting through so accel.accumulate() below
    # actually honours grad_accum_steps (it was previously ignored).
    accel = Accelerator(gradient_accumulation_steps=grad_accum_steps)
    device = accel.device

    # The previous condition also tested `unet is None`, which read the local
    # variable before assignment and raised UnboundLocalError on every resume.
    if resume_checkpoint:
        unet = UNet2DConditionModel.from_pretrained(resume_checkpoint).to(device)
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet").to(device)

    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder").to(device)
    text_encoder.requires_grad_(False)
    vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae").to(device)
    vae.requires_grad_(False)
    scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")

    dataset = CNVDataset(data_dir, tokenizer, resolution=resolution)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
    unet, optimizer, dataloader = accel.prepare(unet, optimizer, dataloader)

    if resume_checkpoint:
        accel.load_state(resume_checkpoint)
        # Checkpoint directories are named "checkpoint_<global_step>".
        global_step = int(resume_checkpoint.rsplit("_", 1)[-1])
    else:
        global_step = 0

    for epoch in range(1, epochs + 1):
        unet.train()
        for batch in dataloader:
            with accel.accumulate(unet):
                pixels = batch["pixel_values"].to(device)
                input_ids = batch["input_ids"].to(device)
                # The VAE and text encoder are frozen, so no autograd graph
                # is needed for them.
                with torch.no_grad():
                    # 0.18215 is the latent scaling constant used with this VAE.
                    latents = vae.encode(pixels).latent_dist.sample() * 0.18215
                    encoder_hidden_states = text_encoder(input_ids)[0]
                noise = torch.randn_like(latents)
                timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device)
                noisy_latents = scheduler.add_noise(latents, noise, timesteps)
                pred_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(pred_noise, noise)
                accel.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1
            if global_step % save_steps == 0:
                accel.wait_for_everyone()
                ckpt_dir = os.path.join(output_dir, f"checkpoint_{global_step}")
                if accel.is_main_process:
                    # Unwrap first: a model wrapped by accelerate for
                    # distributed training does not expose save_pretrained.
                    accel.unwrap_model(unet).save_pretrained(ckpt_dir)
                    tokenizer.save_pretrained(ckpt_dir)
                accel.save_state(ckpt_dir)
        print(f"Epoch {epoch}/{epochs} complete")

    accel.wait_for_everyone()
    final_dir = os.path.join(output_dir, "final_unet")
    if accel.is_main_process:
        accel.unwrap_model(unet).save_pretrained(final_dir)
        tokenizer.save_pretrained(output_dir)
    print(f"Diffusion Model Fine-tuning Complete: Saved to {output_dir}")

# Resize the real CNV images, fine-tune the diffusion UNet on them, then free
# cached memory.
prepare_cnv_images()
train_diffusion_model()
clear_memory()

Dataset downloaded to: /kaggle/input/kermany2018


Exception ignored in: <function _xla_gc_callback at 0x7faf3c4d9e40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 


KeyboardInterrupt: 

## 4. Synthetic CNV Image Generation and ViT Embedding Extraction


In [10]:
def generate_synthetic_images(num_images=640, output_dir="/content/synthetic_cnv"):
    """Generate synthetic CNV images with the fine-tuned diffusion model.

    Loads the base Stable Diffusion pipeline in float16, swaps in the
    fine-tuned UNet and the tokenizer saved during fine-tuning, then samples
    ``num_images`` images one at a time from a fixed prompt. Each image is
    written to ``output_dir`` as ``cnv_synthetic_<index>.png``.

    Args:
        num_images: Number of images to generate.
        output_dir: Directory that receives the PNG files.

    Returns:
        List of the generated PIL images, in generation order.
    """
    os.makedirs(output_dir, exist_ok=True)

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to(device)
    # Replace the stock UNet and tokenizer with the fine-tuned artefacts.
    pipe.unet = UNet2DConditionModel.from_pretrained(
        "/content/sd_cnv_finetuned/final_unet", torch_dtype=torch.float16
    ).to(device)
    pipe.tokenizer = CLIPTokenizer.from_pretrained("/content/sd_cnv_finetuned")

    prompt = "OCT scan showing CNV"
    generated = []
    for index in range(num_images):
        result = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)
        picture = result.images[0]
        target = os.path.join(output_dir, f"cnv_synthetic_{index}.png")
        picture.save(target)
        generated.append(picture)
        # Progress report on every 100th image (including the first).
        if not index % 100:
            print(f"Generated {index}/{num_images} images")
    return generated

def vit_process_synthetic(images, batch_size=8):
    """Embed PIL images with the pretrained ViT.

    Each image is resized to 224x224, converted to a tensor and normalised
    (no crop is applied here). For every image the CLS vector of
    ``hidden_states[12]`` is returned together with three sums of
    ``hidden_states[j] / 4`` over j in 0-3, 4-7 and 8-11.

    Args:
        images: Sequence of PIL images.
        batch_size: Number of images sent through the model at once.

    Returns:
        Tuple of four tensors (last layer, group 0-3, group 4-7, group 8-11),
        each with one row per image.
    """
    to_vit_input = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    vit_model = ViTModel.from_pretrained(
        "google/vit-base-patch16-224", output_hidden_states=True
    ).to(device)
    vit_model.eval()

    # collected[0] holds last-layer vectors; collected[1..3] the group sums.
    collected = ([], [], [], [])
    groups = (range(0, 4), range(4, 8), range(8, 12))
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        batch_tensor = torch.stack([to_vit_input(pic).to(device) for pic in chunk], dim=0)
        with torch.no_grad():
            outputs = vit_model(pixel_values=batch_tensor)
        hidden = outputs.hidden_states
        # Position 0 of the sequence axis is the CLS token.
        cls_last = hidden[12][:, 0, :]
        collected[0].append(cls_last)
        for bucket, layer_ids in zip(collected[1:], groups):
            group_sum = torch.zeros_like(cls_last)
            for layer_id in layer_ids:
                group_sum += hidden[layer_id][:, 0, :] / 4
            bucket.append(group_sum)
        clear_memory()

    return (torch.cat(collected[0], dim=0), torch.cat(collected[1], dim=0),
            torch.cat(collected[2], dim=0), torch.cat(collected[3], dim=0))

# Generate the synthetic CNV images, embed them with ViT and save the result
# in the same layout as pre-trained_dataset.pt.
synthetic_images = generate_synthetic_images()
synthetic_data_0, synthetic_data_1, synthetic_data_2, synthetic_data_3 = vit_process_synthetic(synthetic_images)
torch.save({
    'data_0': synthetic_data_0,
    'data_1': synthetic_data_1,
    'data_2': synthetic_data_2,
    'data_3': synthetic_data_3,
    # One CNV label (0) per embedded image. The length is taken from the
    # embeddings instead of a hard-coded 640 so labels and data stay aligned
    # if the number of generated images changes.
    'label': torch.zeros(synthetic_data_0.shape[0], dtype=torch.long)
}, 'synthetic_cnv_dataset.pt')
print("Synthetic CNV Dataset Saved: synthetic_cnv_dataset.pt")
clear_memory()

OSError: [Errno 28] No space left on device: '/content/synthetic_cnv'

## 5. Define Classifier

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import abc

class TransformerBlock(nn.Module):
    """Multi-head attention followed by a position-wise feed-forward layer.

    Args:
        input_size: Feature size of the query/key/value inputs and the output.
        d_k: Per-head query/key size; ``None`` falls back to ``input_size``.
        d_v: Per-head value size; ``None`` falls back to ``input_size``.
        n_heads: Number of attention heads.
        is_layer_norm: If True, apply the (shared) LayerNorm after each
            residual connection.
        attn_dropout: Dropout probability used on attention weights, the
            attention output and the feed-forward output.
    """

    def __init__(self, input_size, d_k=16, d_v=16, n_heads=8, is_layer_norm=False, attn_dropout=0):
        super(TransformerBlock, self).__init__()
        self.n_heads = n_heads
        self.d_k = d_k if d_k is not None else input_size
        self.d_v = d_v if d_v is not None else input_size
        self.is_layer_norm = is_layer_norm
        if self.is_layer_norm:
            self.layer_norm = nn.LayerNorm(normalized_shape=input_size)
        # Use the resolved self.d_k / self.d_v here: the raw arguments may be
        # None, which previously made `n_heads * d_k` raise a TypeError.
        self.W_q = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_k))
        self.W_k = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_k))
        self.W_v = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_v))
        self.W_o = nn.Parameter(torch.Tensor(self.d_v * n_heads, input_size))
        self.linear1 = nn.Linear(input_size, input_size)
        self.linear2 = nn.Linear(input_size, input_size)
        self.dropout = nn.Dropout(attn_dropout)
        self.__init_weights__()

    def __init_weights__(self):
        """Xavier-normal initialise the projection and feed-forward weights."""
        init.xavier_normal_(self.W_q)
        init.xavier_normal_(self.W_k)
        init.xavier_normal_(self.W_v)
        init.xavier_normal_(self.W_o)
        init.xavier_normal_(self.linear1.weight)
        init.xavier_normal_(self.linear2.weight)

    def FFN(self, X):
        """Apply linear -> ReLU -> linear -> dropout to ``X``."""
        output = self.linear2(F.relu(self.linear1(X)))
        output = self.dropout(output)
        return output

    def scaled_dot_product_attention(self, Q, K, V, episilon=1e-6):
        """Return softmax(Q K^T / (sqrt(d_k) + episilon)) V with dropout on the weights.

        Q, K and V are 3-D tensors whose leading axis is batch * heads.
        """
        temperature = self.d_k ** 0.5
        Q_K = torch.einsum("bqd,bkd->bqk", Q, K) / (temperature + episilon)
        Q_K_score = F.softmax(Q_K, dim=-1)
        Q_K_score = self.dropout(Q_K_score)
        V_att = Q_K_score.bmm(V)
        return V_att

    def multi_head_attention(self, Q, K, V):
        """Project Q/K/V per head, attend, and merge the heads back.

        Args:
            Q: Tensor of shape (batch, q_len, input_size).
            K: Tensor of shape (batch, k_len, input_size).
            V: Tensor of shape (batch, v_len, input_size); v_len must equal
                k_len for the attention product to be defined.

        Returns:
            Tensor of shape (batch, q_len, input_size).
        """
        bsz, q_len, _ = Q.size()
        _, k_len, _ = K.size()
        _, v_len, _ = V.size()
        Q_ = Q.matmul(self.W_q).view(bsz, q_len, self.n_heads, self.d_k)
        K_ = K.matmul(self.W_k).view(bsz, k_len, self.n_heads, self.d_k)
        V_ = V.matmul(self.W_v).view(bsz, v_len, self.n_heads, self.d_v)
        # Fold the head axis into the batch axis so one bmm handles all heads.
        Q_ = Q_.permute(0, 2, 1, 3).contiguous().view(bsz * self.n_heads, q_len, self.d_k)
        K_ = K_.permute(0, 2, 1, 3).contiguous().view(bsz * self.n_heads, k_len, self.d_k)
        V_ = V_.permute(0, 2, 1, 3).contiguous().view(bsz * self.n_heads, v_len, self.d_v)
        V_att = self.scaled_dot_product_attention(Q_, K_, V_)
        V_att = V_att.view(bsz, self.n_heads, q_len, self.d_v)
        V_att = V_att.permute(0, 2, 1, 3).contiguous().view(bsz, q_len, self.n_heads * self.d_v)
        output = self.dropout(V_att.matmul(self.W_o))
        return output

    def forward(self, Q, K, V):
        """Attention with a residual on Q, then feed-forward with a residual."""
        V_att = self.multi_head_attention(Q, K, V)
        if self.is_layer_norm:
            X = self.layer_norm(Q + V_att)
            output = self.layer_norm(self.FFN(X) + X)
        else:
            X = Q + V_att
            output = self.FFN(X) + X
        return output

class EncoderBlock(nn.Module):
    """Look up stored ViT embeddings by row id and project them with an MLP.

    The four embedding tables are initialised from the tensors ``data_0`` ..
    ``data_3`` stored in ``pre-trained_dataset.pt`` (read from the current
    working directory at construction time).

    Args:
        input_dim: Width of the stored embeddings.
        output_dim: Width of the output; must equal ``hidden_dim_1`` because
            the output is added to the first hidden activation.
        hidden_dim_1: Width of the first hidden layer.
        hidden_dim_2: Width of the second hidden layer.
        attn_drop: Dropout probability applied after the second linear layer.
    """

    def __init__(self, input_dim=768, output_dim=300, hidden_dim_1=300, hidden_dim_2=450, attn_drop=0.15):
        super(EncoderBlock, self).__init__()
        self.attn_drop = attn_drop
        dataset_dic = torch.load('pre-trained_dataset.pt')
        embedding_weights_0 = dataset_dic['data_0']
        embedding_weights_1 = dataset_dic['data_1']
        embedding_weights_2 = dataset_dic['data_2']
        embedding_weights_3 = dataset_dic['data_3']
        # Size each table from its own weight tensor instead of a hard-coded
        # 1280 so that datasets of other sizes load without an assertion.
        self.embedding_layer_0 = nn.Embedding(num_embeddings=embedding_weights_0.shape[0], embedding_dim=input_dim, padding_idx=0, _weight=embedding_weights_0)
        self.embedding_layer_1 = nn.Embedding(num_embeddings=embedding_weights_1.shape[0], embedding_dim=input_dim, padding_idx=0, _weight=embedding_weights_1)
        self.embedding_layer_2 = nn.Embedding(num_embeddings=embedding_weights_2.shape[0], embedding_dim=input_dim, padding_idx=0, _weight=embedding_weights_2)
        self.embedding_layer_3 = nn.Embedding(num_embeddings=embedding_weights_3.shape[0], embedding_dim=input_dim, padding_idx=0, _weight=embedding_weights_3)
        self.linear_1 = nn.Linear(input_dim, hidden_dim_1)
        self.linear_2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.linear_3 = nn.Linear(hidden_dim_2, output_dim)
        self.dropout = nn.Dropout(attn_drop)
        self.relu = nn.ReLU()
        self.__init_weights__()

    def __init_weights__(self):
        """Xavier-normal initialise the three linear layers."""
        init.xavier_normal_(self.linear_1.weight)
        init.xavier_normal_(self.linear_2.weight)
        init.xavier_normal_(self.linear_3.weight)

    def forward(self, layer_id=0, X_id=0):
        """Encode the rows ``X_id`` of embedding table ``layer_id``.

        Args:
            layer_id: Which table to read (0, 1, 2 or 3).
            X_id: Integer tensor of row indices.

        Returns:
            Float tensor with ``output_dim`` features per index.

        Raises:
            TypeError: If ``X_id`` is not a tensor.
            ValueError: If ``layer_id`` is not one of 0-3.
        """
        if not torch.is_tensor(X_id):
            # Previously this only printed a warning and then failed with a
            # NameError on the undefined X_; fail with a clear message instead.
            raise TypeError("Non-standard use of encoderblock! X_id must be a tensor of row indices.")
        tables = {
            0: self.embedding_layer_0,
            1: self.embedding_layer_1,
            2: self.embedding_layer_2,
            3: self.embedding_layer_3,
        }
        table = tables.get(layer_id)
        if table is None:
            raise ValueError(f"layer_id must be 0, 1, 2 or 3, got {layer_id!r}")
        # Move the indices to wherever the table lives rather than relying on
        # a module-level `device` global.
        X_ = table(X_id.to(table.weight.device)).to(torch.float32)
        residual = self.relu(self.linear_1(X_))
        x_ = self.relu(self.dropout(self.linear_2(residual)))
        x_ = self.linear_3(x_) + residual
        return x_

class NeuralNetwork(nn.Module):
    """Base class providing the training, evaluation and prediction loops.

    Subclasses implement ``forward(x_id, epoch=0)`` returning class logits.
    """

    def __init__(self):
        super().__init__()
        self.best_acc = 0  # best validation accuracy seen so far
        self.init_clip_max_norm = None

    @abc.abstractmethod
    def forward(self):
        pass

    def _model_device(self):
        """Return the device of the model parameters (CPU if there are none)."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def fit(self, x_train, y_train, x_val, y_val, x_test, y_test):
        """Train for 15 epochs with Adam and evaluate on the validation set after each.

        Args:
            x_train, y_train: Training inputs and integer class labels.
            x_val, y_val: Validation inputs and labels.
            x_test, y_test: Accepted for interface compatibility; not used here.
        """
        if torch.cuda.is_available():
            self.cuda()
        # Send batches to wherever the model lives instead of calling .cuda()
        # unconditionally, which crashed on CPU-only machines.
        run_device = self._model_device()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=8e-5, weight_decay=0)
        dataset = TensorDataset(x_train, y_train)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
        loss = nn.CrossEntropyLoss()
        epochs = 15
        total = len(dataloader)
        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            self.train()
            for i, (batch_x_id, batch_y) in enumerate(dataloader):
                self.batch_dealer(batch_x_id.to(run_device), batch_y.to(run_device), loss, i, epoch + 1, total)
            self.batch_evaluate(x_val, y_val)

    def batch_dealer(self, x_id, y, loss, i, epoch, total):
        """Run one optimisation step on a batch and print its loss and accuracy."""
        self.optimizer.zero_grad()
        logit_original = self.forward(x_id, epoch=epoch)
        loss_classify = loss(logit_original, y)
        loss_classify.backward()
        self.optimizer.step()
        corrects = (torch.max(logit_original, 1)[1].view(y.size()).data == y.data).sum()
        accuracy = 100 * corrects / len(y)
        print(f'Batch[{i + 1}/{total}] - loss: {loss_classify.item():.6f}  accuracy: {accuracy:.4f}%({corrects}/{y.size(0)})')

    def batch_evaluate(self, x, y):
        """Predict on ``x``, print a classification report and track the best accuracy."""
        y_pred = self.predicter(x)
        acc = accuracy_score(y, y_pred)
        if acc > self.best_acc:
            self.best_acc = acc
        # NOTE(review): the report names classes 0/1 as 'NR'/'FR'; elsewhere in
        # this notebook label 0 is CNV and label 1 is NORMAL — confirm naming.
        print(classification_report(y, y_pred, target_names=['NR', 'FR'], digits=5))
        print("Val set acc:", acc)
        print("Best val set acc:", self.best_acc)

    def predicter(self, x):
        """Return the predicted class index for every element of ``x`` as a list."""
        if torch.cuda.is_available():
            self.cuda()
        run_device = self._model_device()
        self.eval()
        y_pred = []
        dataset = TensorDataset(x)
        dataloader = DataLoader(dataset, batch_size=16)
        with torch.no_grad():
            for (batch_x_id,) in dataloader:
                logits = self.forward(batch_x_id.to(run_device))
                predicted = torch.max(logits, dim=1)[1]
                y_pred += predicted.data.cpu().numpy().tolist()
        return y_pred

class Classifier(NeuralNetwork):
    """Two-class classifier over the four stored ViT embedding views.

    Each sample is identified by its row index. The four views of that row
    are encoded by a shared ``EncoderBlock``, concatenated (4 x 300 = 1200
    features), fused to 300 features, passed through a single-token
    ``TransformerBlock`` and classified by a small MLP.
    """

    def __init__(self):
        super().__init__()
        self.encoder_block = EncoderBlock()
        self.attention = TransformerBlock(input_size=300)
        self.dropout = nn.Dropout(0.6)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(1200, 300)
        self.fc1 = nn.Linear(300, 600)
        self.fc2 = nn.Linear(600, 300)
        self.fc3 = nn.Linear(in_features=300, out_features=2)
        self.init_weight()

    def init_weight(self):
        """Xavier-normal initialise the fully connected layers."""
        init.xavier_normal_(self.fc.weight)
        init.xavier_normal_(self.fc1.weight)
        init.xavier_normal_(self.fc2.weight)
        init.xavier_normal_(self.fc3.weight)

    def forward(self, x_id, epoch=0):
        """Return logits of shape (batch, 2) for the row indices ``x_id``.

        Args:
            x_id: 1-D integer tensor of embedding row indices.
            epoch: Accepted for interface compatibility; not used.
        """
        batch_size = x_id.shape[0]
        # Follow the model's own device instead of forcing .cuda(), which
        # raised on machines without a GPU.
        x_id = x_id.to(self.fc.weight.device)
        embedding_0 = self.encoder_block(layer_id=0, X_id=x_id)
        embedding_1 = self.encoder_block(layer_id=1, X_id=x_id)
        embedding_2 = self.encoder_block(layer_id=2, X_id=x_id)
        embedding_3 = self.encoder_block(layer_id=3, X_id=x_id)
        embedding = self.relu(self.fc(torch.cat((embedding_0, embedding_1, embedding_2, embedding_3), dim=1)))
        # The attention block expects (batch, seq, features); use seq length 1.
        token = embedding.view(batch_size, 1, 300)
        enhanced = self.attention(token, token, token)
        enhanced = enhanced.squeeze(1)
        a1 = self.relu(self.dropout(self.fc1(enhanced)))
        a1 = self.relu(self.dropout(self.fc2(a1)))
        output = self.fc3(a1)
        return output

def train_and_test(model, x_train, y_train, x_val, y_val, x_test, y_test):
    """Fit ``model``, predict on the test split and print and return the report.

    Returns:
        The dict produced by ``classification_report(..., output_dict=True)``
        for the test split.
    """
    model.fit(x_train, y_train, x_val, y_val, x_test, y_test)
    predictions = model.predicter(x_test)
    report = classification_report(
        y_test,
        predictions,
        target_names=['NR', 'FR'],
        digits=3,
        output_dict=True,
    )
    for name in report:
        print(name, report[name])
    print(f"result: {report['accuracy']:.4f}")
    return report

## 6. Classifier Training and Result Analysis by Synthetic Data Ratio


In [None]:
def train_classifier_with_ratios():
    """Train one classifier per synthetic-data ratio and compare the results.

    Split design (1280 rows in total, as before):
      * Validation (rows 1024-1151) and test (rows 1152-1279) each hold
        64 real CNV + 64 real NORMAL samples, fixed across all ratios.
      * Training (rows 0-1023) holds 512 NORMAL samples plus 512 CNV samples,
        of which ``ratio`` are synthetic and the rest real; rows are shuffled.

    The previous version concatenated [synthetic CNV, real CNV, NORMAL]
    without shuffling, so the validation and test slices contained only
    NORMAL samples and the reported metrics were meaningless.

    ``Classifier`` reads its embeddings from ``pre-trained_dataset.pt``, so
    the mixed data for each ratio is written to that path; the original
    content is restored when the function exits, even on error.

    Results are printed and written to ``classifier_results.csv``.

    Raises:
        ValueError: If the stored datasets contain too few samples.
    """
    dataset_path = 'pre-trained_dataset.pt'
    keys = ('data_0', 'data_1', 'data_2', 'data_3', 'label')
    train_per_class, eval_per_class = 512, 64
    n_train = 2 * train_per_class
    n_eval = 2 * eval_per_class

    # Load on the CPU so tensors saved from a GPU session can be indexed and
    # concatenated together regardless of the current hardware.
    original_dataset = torch.load(dataset_path, map_location='cpu')
    synthetic_dataset = torch.load('synthetic_cnv_dataset.pt', map_location='cpu')

    def take(dataset, rows):
        """Return the selected rows of every tensor in ``dataset``."""
        return {key: dataset[key][rows] for key in keys}

    # Label 0 marks CNV rows and label 1 marks NORMAL rows.
    real_labels = original_dataset['label']
    cnv_rows = torch.nonzero(real_labels == 0).flatten()
    normal_rows = torch.nonzero(real_labels == 1).flatten()
    num_synthetic_available = synthetic_dataset['label'].shape[0]
    needed_real = train_per_class + 2 * eval_per_class
    if len(cnv_rows) < needed_real or len(normal_rows) < needed_real:
        raise ValueError(
            f"Need at least {needed_real} real samples per class, got "
            f"{len(cnv_rows)} CNV and {len(normal_rows)} NORMAL"
        )
    if num_synthetic_available < train_per_class:
        raise ValueError(
            f"Need at least {train_per_class} synthetic samples, got {num_synthetic_available}"
        )

    # Fix the real train/val/test assignment once so every ratio is scored on
    # the same held-out real samples.
    cnv_rows = cnv_rows[torch.randperm(len(cnv_rows))]
    normal_rows = normal_rows[torch.randperm(len(normal_rows))]
    val_end = train_per_class + eval_per_class
    test_end = val_end + eval_per_class
    cnv_train_pool = cnv_rows[:train_per_class]
    normal_train = take(original_dataset, normal_rows[:train_per_class])
    val_split = take(original_dataset, torch.cat([cnv_rows[train_per_class:val_end], normal_rows[train_per_class:val_end]]))
    test_split = take(original_dataset, torch.cat([cnv_rows[val_end:test_end], normal_rows[val_end:test_end]]))

    # linspace avoids the floating-point drift of np.arange(0, 1.1, 0.1).
    ratios = np.linspace(0.0, 1.0, 11)
    results = []

    try:
        for ratio in ratios:
            print(f"\nTraining with Synthetic Ratio: {ratio*100:.0f}%")
            num_synthetic = int(round(train_per_class * ratio))
            num_real_cnv = train_per_class - num_synthetic
            synthetic_rows = torch.randperm(num_synthetic_available)[:num_synthetic]
            real_cnv_rows = cnv_train_pool[torch.randperm(train_per_class)[:num_real_cnv]]

            train_parts = [
                take(synthetic_dataset, synthetic_rows),
                take(original_dataset, real_cnv_rows),
                normal_train,
            ]
            # One permutation shared by all keys keeps data and labels aligned.
            order = torch.randperm(n_train)
            mixed = {}
            for key in keys:
                train_block = torch.cat([part[key] for part in train_parts], dim=0)[order]
                mixed[key] = torch.cat([train_block, val_split[key], test_split[key]], dim=0)

            # Samples are addressed by their row index in the embedding table.
            x_train = torch.arange(0, n_train)
            x_val = torch.arange(n_train, n_train + n_eval)
            x_test = torch.arange(n_train + n_eval, n_train + 2 * n_eval)
            y_train = mixed['label'][:n_train]
            y_val = mixed['label'][n_train:n_train + n_eval]
            y_test = mixed['label'][n_train + n_eval:n_train + 2 * n_eval]

            # Save the embeddings where Classifier (via EncoderBlock) loads them.
            torch.save(mixed, dataset_path)

            model = Classifier()
            res = train_and_test(model, x_train, y_train, x_val, y_val, x_test, y_test)

            results.append({
                'ratio': ratio,
                'accuracy': res['accuracy'],
                'f1_score': res['macro avg']['f1-score'],
                'precision': res['macro avg']['precision'],
                'recall': res['macro avg']['recall']
            })
            clear_memory()
    finally:
        # Put the real dataset back so re-running the notebook does not start
        # from an already-mixed file.
        torch.save(original_dataset, dataset_path)

    print("\nResults Summary:")
    for res in results:
        print(f"Ratio: {res['ratio']*100:.0f}% | Accuracy: {res['accuracy']:.4f} | F1 Score: {res['f1_score']:.4f} | Precision: {res['precision']:.4f} | Recall: {res['recall']:.4f}")

    best_result = max(results, key=lambda x: x['accuracy'])
    print(f"\nBest Ratio: {best_result['ratio']*100:.0f}%")
    print(f"Accuracy: {best_result['accuracy']:.4f}")
    print(f"F1 Score: {best_result['f1_score']:.4f}")
    print(f"Precision: {best_result['precision']:.4f}")
    print(f"Recall: {best_result['recall']:.4f}")

    pd.DataFrame(results).to_csv('classifier_results.csv', index=False)
    print("Results saved to classifier_results.csv")

# Run the synthetic-ratio experiment.
train_classifier_with_ratios()

## 7. Quality Analysis of Synthetic Images (FID Score)


In [None]:
def calculate_fid():
    """Compute and print the FID between the real and synthetic CNV image folders.

    Compares ``/content/processed/CNV`` with ``/content/synthetic_cnv`` using
    2048-dimensional features and a batch size of 50.

    Returns:
        The FID value computed by ``pytorch_fid``.
    """
    # Fall back to the CPU when no GPU is present instead of hard-coding
    # 'cuda', consistent with the device selection used in the setup cell.
    fid_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fid = fid_score.calculate_fid_given_paths(
        ['/content/processed/CNV', '/content/synthetic_cnv'],
        batch_size=50,
        device=fid_device,
        dims=2048
    )
    print(f"FID Score: {fid}")
    return fid

# Score the quality of the generated images.
calculate_fid()