In [None]:
!git clone https: // github.com/Hanson0910/Pytorch-RIADD.git

In [None]:
import sys

import seaborn as sns
import tenseal as ts

# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.insert(1, '/content/Pytorch-RIADD')
import os
import pickle
import lightning as L
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
import albumentations
from albumentations.pytorch import ToTensorV2
import cv2
from glob import glob
from imblearn.over_sampling import RandomOverSampler
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

import pandas as pd

import torch
from torch import nn

In [None]:

# Side length (pixels) that images are resized to.
image_size = 64
# Resize-only pipeline without normalization.
# NOTE(review): not referenced by any other cell visible in this notebook.
train_trans = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
    ToTensorV2(),
])

# Training-time augmentation applied per sample in RetinaDataset.__getitem__:
# random flips, blur, noise, colour jitter and cut-out, followed by channel
# normalization and conversion to a tensor.
# NOTE(review): IAAAdditiveGaussianNoise and Cutout are no longer shipped by
# recent albumentations releases - confirm the installed version provides them.
train_trans_batch = albumentations.Compose([
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.MedianBlur(blur_limit=7, p=0.3),
    albumentations.IAAAdditiveGaussianNoise(scale=(0, 0.15 * 255), p=0.5),
    albumentations.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.3),
    albumentations.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.3),
    albumentations.Cutout(max_h_size=20, max_w_size=20, num_holes=5, p=0.5),
    albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2(),
])

# Evaluation pipeline: the same normalization and tensor conversion as training,
# without any random augmentation.
valid_trans_batch = albumentations.Compose([
    albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2(),
])

In [None]:
class RetinaDataset(Dataset):
    """In-memory image dataset holding a list of ``(image, label)`` pairs.

    The pairs are either loaded from a pickle written by an earlier run
    (``presaved_data_path``) or built from an image folder plus a label CSV.

    Args:
        data_folder: Directory whose files are named ``<ID>.<ext>`` with an
            integer ``ID``. Only read when ``presaved_data_path`` is None.
        label_path: CSV file with at least the columns ``ID`` and
            ``Disease_Risk``. Only read when ``presaved_data_path`` is None.
        upsample: If truthy, classes are balanced with
            ``RandomOverSampler(random_state=0)`` before the images are read.
        presaved_data_path: Pickle file to load the pairs from, or None to
            build them from ``data_folder`` / ``label_path``.
        save_data_path: Where to pickle the freshly built pairs; None skips
            saving. Ignored when ``presaved_data_path`` is given.
        transform: Callable invoked as ``transform(image=img)["image"]`` on
            every access.
        image_size: Side length the raw images are resized to when building.

    Raises:
        OSError: If an image listed for loading cannot be read by OpenCV.
    """

    def __init__(self, data_folder, label_path, upsample, presaved_data_path, save_data_path, transform, image_size):

        self.simple_transform = albumentations.Compose([
            albumentations.Resize(image_size, image_size),
        ])

        self.transform = transform

        if presaved_data_path is not None:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at files produced by this notebook.
            with open(presaved_data_path, "rb") as f:
                self.data = pickle.load(f)
            return

        label_frame = pd.read_csv(label_path)

        # Map every image ID (the file stem) to its path in one pass instead of
        # scanning the label frame once per file.
        id_to_path = {}
        for image_name in glob(data_folder + "/*"):
            image_id = int(os.path.basename(image_name).split(".")[0])
            id_to_path[image_id] = image_name

        # Rows without a matching file become NaN and are dropped; this also
        # works when the folder is empty.
        label_frame["image_path"] = label_frame["ID"].map(id_to_path)
        label_frame = label_frame[label_frame["image_path"].notna()]

        X_train = label_frame["image_path"].values.reshape(-1, 1)
        y_train = label_frame["Disease_Risk"]
        if upsample:
            ros = RandomOverSampler(random_state=0)
            X_train, y_train = ros.fit_resample(X_train, y_train)

        self.data = []
        for image, label in tqdm(zip(X_train.flatten(), y_train), total=len(y_train)):
            # NOTE(review): cv2.imread returns channels in BGR order; confirm
            # the normalization statistics used downstream expect that order.
            img = cv2.imread(image)
            if img is None:
                raise OSError("cv2.imread could not read image: %s" % image)
            img = self.simple_transform(image=img)["image"]
            self.data.append((img, label))

        # Caching is optional: callers may pass save_data_path=None.
        if save_data_path is not None:
            with open(save_data_path, "wb") as f:
                pickle.dump(self.data, f)

    def __len__(self):
        """Return the number of stored ``(image, label)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the pair at ``idx``."""
        image, label = self.data[idx]
        return self.transform(image=image)["image"], label

In [None]:
# Training split. Because presaved_data_path is set, RetinaDataset loads the
# cached pickle and does not read data_folder / label_path or apply upsample.
train_dataset = RetinaDataset(data_folder="/content/Training_Set/Training",
                              label_path="/content/Training_Set/RFMiD_Training_Labels.csv", upsample=True,
                              presaved_data_path="/content/drive/MyDrive/CapstoneData/train_data_upsampe.pkl",
                              save_data_path=None, transform=train_trans_batch, image_size=64)
# Shuffled mini-batches of 32 for training.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=8)

In [None]:
# Validation split. Because presaved_data_path is set, RetinaDataset loads the
# cached pickle and does not read data_folder / label_path or apply upsample.
val_dataset = RetinaDataset(data_folder="/content/Evaluation_Set/Validation",
                            label_path="/content/Evaluation_Set/RFMiD_Validation_Labels.csv", upsample=False,
                            presaved_data_path="/content/drive/MyDrive/CapstoneData/validation_data_upsampe.pkl",
                            save_data_path=None, transform=valid_trans_batch, image_size=64)
# Evaluation loaders keep a fixed order: shuffling adds nothing to the metrics
# and makes per-batch results harder to compare between runs.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=8)

In [None]:
# Test split. Because presaved_data_path is set, RetinaDataset loads the
# cached pickle and does not read data_folder / label_path or apply upsample.
test_dataset = RetinaDataset(data_folder="/content/Test_Set/Test",
                             label_path="/content/Test_Set/RFMiD_Testing_Labels.csv", upsample=False,
                             presaved_data_path="/content/drive/MyDrive/CapstoneData/test_data_upsampe.pkl",
                             save_data_path=None, transform=valid_trans_batch, image_size=64)
# Evaluation loaders keep a fixed order: shuffling adds nothing to the metrics
# and makes per-batch results harder to compare between runs.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=8)

In [None]:


# Fix the random seed through Lightning and ask cuDNN for deterministic
# behaviour (no auto-tuned kernel selection) so runs are repeatable.
L.seed_everything(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Report which device is available. Note that train_model() below builds its
# Trainer with accelerator="cpu", so this value is informational only there.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
# DATASET_PATH = "data/"
# Directory for checkpoints and logs; the PATH_CHECKPOINT environment variable
# overrides the Google Drive default.
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "/content/drive/MyDrive/CapstoneData/")

In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """Cut a batch of images into non-overlapping square patches.

    Inputs:
        x - Image tensor of shape [B, C, H, W]
        patch_size - Side length of each patch in pixels (integer)
        flatten_channels - When True every patch is returned as one feature
                           vector of length C*patch_size*patch_size; otherwise
                           patches keep their [C, patch_size, patch_size] layout.
    Returns:
        Tensor of shape [B, H'*W', C*p*p] or [B, H'*W', C, p, p], where
        H' = H // patch_size and W' = W // patch_size.
    """
    batch, channels, height, width = x.shape
    rows = height // patch_size
    cols = width // patch_size

    # Split both spatial axes into (grid index, offset inside the patch), move
    # the grid axes in front of the channels, then merge them into one axis.
    grid = x.reshape(batch, channels, rows, patch_size, cols, patch_size)
    patches = grid.permute(0, 2, 4, 1, 3, 5).flatten(1, 2)

    return patches.flatten(2, 4) if flatten_channels else patches

In [None]:
class EncLayerNorm(nn.Module):
    """Layer normalization whose ``1/sqrt(var + eps)`` is a fixed polynomial.

    The module mirrors ``nn.LayerNorm`` over the trailing ``normalized_shape``
    dimensions but evaluates the inverse square root with a degree-5
    polynomial (additions and multiplications only), reusing the affine
    parameters of an existing layer norm.

    Args:
        normalized_shape: Trailing input shape to normalize over; an int, a
            list, a tuple or a ``torch.Size``.
        lnorm: Layer-norm module whose ``weight`` and ``bias`` are reused
            (either may be None, in which case that step is skipped).
        eps: Constant added to the variance before the polynomial is applied.
    """

    # Highest-degree coefficient first; evaluated with Horner's scheme.
    _INV_SQRT_COEFFS = (
        -2.5310791822175722e-05,
        0.0011722736211444353,
        -0.020442863849468841,
        0.16784042867490598,
        -0.6769490522498337,
        1.5386227127260448,
    )

    def __init__(self, normalized_shape, lnorm, *,
                 eps: float = 1e-5):

        super().__init__()

        if isinstance(normalized_shape, int):
            normalized_shape = torch.Size([normalized_shape])
        elif isinstance(normalized_shape, (list, tuple)):
            # Also covers the plain tuple stored by nn.LayerNorm.normalized_shape.
            normalized_shape = torch.Size(normalized_shape)
        assert isinstance(normalized_shape, torch.Size)

        self.normalized_shape = normalized_shape
        self.eps = eps

        self.gain = lnorm.weight
        self.bias = lnorm.bias

    def taylor_sqrt(self, x):
        """Return the polynomial approximation of ``1/sqrt(x)``.

        Despite the name this is neither a Taylor series nor a square root: it
        uses the same coefficients as ``remez_sqrt`` elsewhere in this
        notebook, where the polynomial is compared against ``1/sqrt(x)`` on
        [0.8, 16]. Outside that interval the accuracy is not checked here.
        """
        coeffs = self._INV_SQRT_COEFFS
        u = coeffs[0]
        for coeff in coeffs[1:]:
            u = u * x + coeff
        return u

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` over its trailing ``normalized_shape`` dimensions."""

        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

        dims = [-(i + 1) for i in range(len(self.normalized_shape))]

        mean = x.mean(dim=dims, keepdim=True)

        # Biased variance computed as E[x^2] - E[x]^2.
        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
        var = mean_x2 - mean ** 2

        x_norm = (x - mean) * self.taylor_sqrt(var + self.eps)
        if self.gain is not None:
            x_norm = self.gain * x_norm
        if self.bias is not None:
            x_norm = x_norm + self.bias

        return x_norm

# Encrypted layer normalization

In [None]:
def create_ctx(bits_scale=40, poly_mod_degree=16384, num_mul=5):
    """Build a TenSEAL CKKS context.

    Args:
        bits_scale: Bit size of each middle coefficient-modulus prime; the
            global scale is set to ``2 ** bits_scale``.
        poly_mod_degree: Polynomial modulus degree passed to ``ts.context``.
        num_mul: Number of middle primes of ``bits_scale`` bits placed between
            the two fixed 40-bit outer primes.

    Returns:
        The context, with its global scale set and Galois keys generated
        (presumably needed by the ``dot`` calls made on encrypted vectors in
        this notebook - confirm against the TenSEAL documentation).
    """
    # Layout: [40, bits_scale * num_mul ..., 40]
    coeff_mod_bit_sizes = [40] + [bits_scale] * num_mul + [40]

    ctx = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
    ctx.global_scale = 2 ** bits_scale
    ctx.generate_galois_keys()
    return ctx

In [None]:

# Context for the experiments below: degree 32768 with 15 middle primes, i.e.
# create_ctx builds a coefficient modulus of 40 + 15 * 40 + 40 bits.
context = create_ctx(poly_mod_degree=32768, num_mul=15)


In [None]:
def remez_sqrt(x):
    """Evaluate the degree-5 polynomial this notebook uses in place of 1/sqrt(x).

    Only multiplications and additions with constants are applied to ``x``,
    so any operand supporting those works (the notebook passes floats, NumPy
    arrays and encrypted CKKS vectors). The cells below compare the result
    against ``1/sqrt(x)`` on the interval [0.8, 16].
    """
    lower_order_coeffs = (
        0.0011722736211444353,
        -0.020442863849468841,
        0.16784042867490598,
        -0.6769490522498337,
        1.5386227127260448,
    )
    # Horner's scheme, starting from the leading coefficient.
    acc = -2.5310791822175722e-05
    for coeff in lower_order_coeffs:
        acc = acc * x + coeff
    return acc


def norm(x_enc):
    """Normalize a 1-D vector to zero mean and (approximately) unit variance.

    The mean and the biased variance are obtained through dot products with a
    vector of ones, and the inverse standard deviation comes from
    ``remez_sqrt``. The notebook calls this with both encrypted CKKS vectors
    and plain NumPy arrays; no epsilon is added to the variance.
    """
    length = x_enc.shape[0]
    ones = [1] * length
    inv_length = 1 / length

    mean = x_enc.dot(ones) * inv_length
    centered = x_enc - mean
    var = (centered ** 2).dot(ones) * inv_length
    return centered * remez_sqrt(var)

In [None]:
import numpy as np

# Sample the evaluation interval [0.8, 16] and compute the exact 1/sqrt(x) reference.
x_ = np.linspace(0.8, 16, 300)
y_ = 1 / (x_ ** (1 / 2))

In [None]:
# Evaluate the polynomial on the encrypted samples, then decrypt for comparison.
enc_x_ = ts.ckks_vector(context, x_)
y2_ = remez_sqrt(enc_x_).decrypt()


In [None]:
# Standard deviation of the absolute error between the exact and decrypted values.
np.abs(y_ - y2_).std()

In [None]:
# Overlay the decrypted polynomial output on the exact curve.
plt.plot(x_, y2_, label=f"Fully Encrypted Remez Polynomial with degree 5")
plt.plot(x_, y_, label=f"1/sqrt(x) function")
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.title(r'Y=1/sqrt(x) Polynomial Approximation')

In [None]:
# data = torch.rand(256)*11+0.8
# 256 uniform samples in [0.8, 11.8) used as a stand-in layer-norm input.
data = np.random.rand(256) * 11 + 0.8

In [None]:
# Variance of the sample; it is the value the polynomial is evaluated at inside norm().
data.var()

In [None]:
# Encrypt the sample with the CKKS context.
enc_v1 = ts.ckks_vector(context, data)

In [None]:
# Normalize under encryption and decrypt the result into a tensor.
enc_res = torch.Tensor(norm(enc_v1).decrypt())

In [None]:
# Same polynomial-based normalization on the plaintext data, so the difference
# to enc_res reflects the encrypted evaluation rather than the polynomial.
non_res = norm(data)

In [None]:
# Standard deviation of the absolute difference (encrypted vs plaintext).
np.abs(enc_res - non_res).std()

In [None]:
# Mean absolute difference (encrypted vs plaintext).
np.abs(enc_res - non_res).mean()

In [None]:
# Recompute the encrypted mean and variance step by step. The mean gets its own
# name so that `enc_res` (the decrypted normalization result computed earlier)
# is not overwritten before the error-analysis cells that follow use it.
enc_mean = (enc_v1.dot([1] * enc_v1.shape[0]) * (1 / enc_v1.shape[0]))
var = (((enc_v1 - enc_mean) ** 2).dot([1] * enc_v1.shape[0]) * (1 / enc_v1.shape[0]))

In [None]:
# Decrypted variance, to be compared with the plaintext variance in the next cell.
var.decrypt()

In [None]:
# Plaintext variance of the same sample.
np.var(data)

In [None]:
# Exact normalization (true mean and standard deviation) used as the reference.
non_res = (data - np.mean(data)) / np.std(data)

In [None]:
# Mean absolute difference between `enc_res` and the exact reference.
# NOTE(review): this and the following cells expect `enc_res` to be the
# decrypted output of norm(enc_v1); make sure no cell in between rebinds it.
np.abs(enc_res - non_res).mean()

In [None]:


# Histogram of the per-element absolute error, saved to error_dist.png.
sns.histplot(torch.abs(enc_res - non_res))
plt.title("Error distribution of encrypted layer normalization")
plt.xlabel("Error")
plt.ylabel("Count")
plt.savefig("error_dist.png", dpi=300)

In [None]:
# Mean absolute error, computed with torch this time.
torch.abs(enc_res - non_res).mean()

# Main model architecture and training

In [None]:
class AttentionBlock(nn.Module):
    """Transformer encoder block: self-attention followed by a feed-forward
    network, each applied to a layer-normalized input and added back to the
    residual stream."""

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Size of the token vectors entering and leaving the block
            hidden_dim - Width of the hidden layer inside the feed-forward network
            num_heads - Number of attention heads
            dropout - Dropout probability used inside the feed-forward network
        """
        super().__init__()

        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)

        feed_forward = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.linear = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers with residual connections."""
        normed = self.layer_norm_1(x)
        attended, _ = self.attn(normed, normed, normed)
        x = x + attended

        return x + self.linear(self.layer_norm_2(x))

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier operating on flattened image patches."""

    def __init__(
            self,
            embed_dim,
            hidden_dim,
            num_channels,
            num_heads,
            num_layers,
            num_classes,
            patch_size,
            num_patches,
            dropout=0.0,
    ):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size ** 2), embed_dim)
        self.transformer = nn.Sequential(
            *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
        )

        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def _embed(self, x):
        """Turn images [B, C, H, W] into the token sequence [T + 1, B, embed_dim]."""
        x = img_to_patch(x, self.patch_size)

        B, T, _ = x.shape
        x = self.input_layer(x)

        # Prepend the CLS token and add the positional encoding.
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, : T + 1]
        x = self.dropout(x)

        # The attention blocks expect the sequence dimension first.
        return x.transpose(0, 1)

    def forward(self, x):
        """Return class logits of shape [B, num_classes] for a batch of images."""
        x = self.transformer(self._embed(x))

        # Classify from the CLS token, which sits at sequence position 0.
        return self.mlp_head(x[0])

    def encforward(self, x):
        """Like ``forward``, but the first block's ``layer_norm_1`` is
        temporarily replaced by an ``EncLayerNorm`` sharing its parameters.

        The replacement is sized from the layer being replaced, and the
        original layer is restored even if the forward pass raises.
        """
        x = self._embed(x)

        first_block = self.transformer[0]
        lnorm_layer = first_block.layer_norm_1
        first_block.layer_norm_1 = EncLayerNorm(torch.Size(lnorm_layer.normalized_shape), lnorm_layer)
        try:
            x = self.transformer(x)
        finally:
            first_block.layer_norm_1 = lnorm_layer

        # Classify from the CLS token, which sits at sequence position 0.
        return self.mlp_head(x[0])

In [None]:
class ViT(L.LightningModule):
    """Lightning wrapper that trains and evaluates a ``VisionTransformer``.

    Training and validation log ``<mode>_loss`` / ``<mode>_acc``. Testing
    additionally runs ``VisionTransformer.encforward`` and logs the
    corresponding ``<mode>_loss_enc`` / ``<mode>_acc_enc`` values.
    """

    def __init__(self, model_kwargs, lr):
        """
        Inputs:
            model_kwargs - Keyword arguments forwarded to VisionTransformer
            lr - Learning rate for the AdamW optimizer
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
        # One batch of images from the module-level train_loader serves as the
        # example input.
        first_batch = next(iter(train_loader))
        self.example_input_array = first_batch[0]

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """AdamW with a step schedule that scales the rate by 0.1 at epochs 100 and 150."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    @staticmethod
    def _loss_and_accuracy(preds, labels):
        """Return (cross-entropy loss, mean top-1 accuracy) for a batch of logits."""
        loss = F.cross_entropy(preds, labels)
        hits = preds.argmax(dim=-1) == labels
        return loss, hits.float().mean()

    def _calculate_loss(self, batch, mode="train"):
        """Run the plain forward pass on ``batch``, log the metrics, return the loss."""
        imgs, labels = batch
        loss, acc = self._loss_and_accuracy(self.model(imgs), labels)

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def _calculate_loss_test(self, batch, mode="test"):
        """Evaluate ``batch`` with both ``encforward`` and the plain forward pass.

        All four metrics are logged; the loss of the plain pass is returned.
        """
        imgs, labels = batch

        loss_enc, acc_enc = self._loss_and_accuracy(self.model.encforward(imgs), labels)
        loss, acc = self._loss_and_accuracy(self.model(imgs), labels)

        self.log("%s_loss_enc" % mode, loss_enc)
        self.log("%s_acc_enc" % mode, acc_enc)
        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss_test(batch, mode="test")

In [None]:
def train_model(**kwargs):
    """Load a pretrained ViT if one exists, otherwise train one, then evaluate it.

    Inputs:
        **kwargs - Forwarded to the ``ViT`` constructor when training from scratch.
    Returns:
        (model, result), where ``result`` holds the logged ``test_acc`` for the
        test loader under "test" and for the validation loader under "val".
    """
    checkpoint_cb = ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_loss")
    lr_monitor = LearningRateMonitor("epoch")
    trainer = L.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
        accelerator="cpu",
        devices=1,
        max_epochs=1,
        callbacks=[checkpoint_cb, lr_monitor],
    )
    # Plot the computation graph in tensorboard and drop the default hp metric.
    trainer.logger._log_graph = True
    trainer.logger._default_hp_metric = None

    # A checkpoint named ViT.ckpt in CHECKPOINT_PATH short-circuits training.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
    if not os.path.isfile(pretrained_filename):
        L.seed_everything(42)  # keep training runs reproducible
        model = ViT(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Continue with the best checkpoint recorded during training.
        model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    else:
        print("Found pretrained model at %s, loading..." % pretrained_filename)
        # The checkpoint carries the saved hyperparameters.
        model = ViT.load_from_checkpoint(pretrained_filename)

    print(trainer.checkpoint_callback.best_model_path)

    # Both splits go through test_step, so both report their accuracy as "test_acc".
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    print(test_result)

    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}
    return model, result

In [None]:
# Train (or load) the ViT and evaluate it on the validation and test loaders.
# Assuming 64x64 inputs (the image_size used for the datasets above), 8x8
# patches give (64 / 8) ** 2 = 64 patches, which matches num_patches.
model, results = train_model(
    model_kwargs={
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 8,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 2,
        "dropout": 0.2,
    },
    lr=3e-4,
)
print("ViT results", results)