In [1]:
from einops import rearrange
import copy
import h5py
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from pdb import set_trace
import matplotlib.pyplot as plt
from torch import nn
from x_transformers import  Encoder, Decoder
from x_transformers.autoregressive_wrapper import exists
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from fastai.vision.all import BCEWithLogitsLossFlat
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from fastprogress.fastprogress import master_bar, progress_bar
import os
from timm import create_model

In [2]:
class CFG:
    bs = 128
    nw = 4
    model_name = "VIT"
    lr = 1e-4
    wd = 1e-4
    epoch = 50
    warmup_pct = 0.1
    num_classes = 1
    dropout_rate = 0.3
    folder = "EXP_120_00_VIT_36"
    mixup=False
    exp_name = f"{folder}_{model_name}"

In [3]:
def get_snr(left, right, df):
    df_ = pd.concat([df.query(f"snr>{left} & snr<{right}"), df.query("snr==0")])
    return df_


def generate_report(df, p, fn):
    pred = torch.sigmoid(p).cpu().numpy().reshape(-1)
    val_df_eval = df.copy()
    val_df_eval["pred"] = pred
    val_df_eval.to_csv(f"{fn}_oof.csv")

    roc_100 = roc_auc_score(val_df_eval["target"], val_df_eval["pred"])
    roc_0_50 = roc_auc_score(
        get_snr(0, 50, val_df_eval)["target"], get_snr(0, 50, val_df_eval)["pred"]
    )
    roc_15_50 = roc_auc_score(
        get_snr(15, 50, val_df_eval)["target"], get_snr(15, 50, val_df_eval)["pred"]
    )
    roc_25_50 = roc_auc_score(
        get_snr(25, 50, val_df_eval)["target"], get_snr(25, 50, val_df_eval)["pred"]
    )
    roc_0_40 = roc_auc_score(
        get_snr(0, 40, val_df_eval)["target"], get_snr(0, 40, val_df_eval)["pred"]
    )

    roc_0_30 = roc_auc_score(
        get_snr(0, 30, val_df_eval)["target"], get_snr(0, 30, val_df_eval)["pred"]
    )

    return {
        "roc_all": roc_100,
        "roc_0_50": roc_0_50,
        "roc_15_50": roc_15_50,
        "roc_25_50": roc_25_50,
        "roc_0_40": roc_0_40,
        "roc_0_30": roc_0_30,
    }

class SaveModel:
    def __init__(self, folder, exp_name, best=np.inf):
        self.best = best
        self.folder = Path(folder) / f"{exp_name}.pth"

    def __call__(self, score, model, epoch):
        if score < self.best:
            self.best = score
            print(f"Better model found at epoch {epoch} with value: {self.best}.")
            torch.save(model.state_dict(), self.folder)


class SaveModelMetric:
    def __init__(self, folder, exp_name, best=-np.inf):
        self.best = best
        self.folder = Path(folder) / f"{exp_name}.pth"

    def __call__(self, score, model, epoch):
        if score > self.best:
            self.best = score
            print(f"Better model found at epoch {epoch} with value: {self.best}.")
            torch.save(model.state_dict(), self.folder)


class SaveModelEpoch:
    def __init__(self, folder, exp_name, best=-np.inf):
        self.best = best
        self.folder = Path(folder)
        self.exp_name = exp_name

    def __call__(self, score, model, epoch):
        self.best = score
        print(f"Better model found at epoch {epoch} with value: {self.best}.")
        torch.save(model.state_dict(), f"{self.folder/self.exp_name}_{epoch}.pth")


def custom_auc_score(p, gt):
    return roc_auc_score(gt.cpu().numpy(),  torch.sigmoid(p).cpu().numpy().reshape(-1))


def fit_mixup(
    epochs,
    model,
    train_dl,
    valid_dl,
    loss_fn,
    opt,
    metric,
    val_df,
    folder="models",
    exp_name="exp_00",
    device=None,
    sched=None,
    mixup_=False,
    save_md=SaveModel,
):
    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    os.makedirs(folder, exist_ok=True)
    loss_fn_trn = loss_fn
    if mixup_:
        mixup = Mixup(num_classes=2, mixup_alpha=0.4, prob=0.8)
        loss_fn_trn = BinaryCrossEntropy()
    mb = master_bar(range(epochs))

    mb.write(
        [
            "epoch",
            "train_loss",
            "valid_loss",
            "val_metric",
            "roc_all",
            "roc_0_50",
            "roc_15_50",
            "roc_25_50",
            "roc_0_40",
            "roc_0_30",
        ],
        table=True,
    )
    model.to(device)  # we have to put our model on gpu
    #scaler = torch.cuda.amp.GradScaler()  # this for half precision training
    save_md = save_md(folder, exp_name)

    for i in mb:  # iterating  epoch
        trn_loss, val_loss = 0.0, 0.0
        trn_n, val_n = len(train_dl.dataset), len(valid_dl.dataset)
        model.train()  # set model for training
        for (xb, yb) in progress_bar(train_dl, parent=mb):
            xb, yb = xb.to(device), yb.to(device)  # putting batches to device
            if mixup_:
                xb, yb = mixup(xb, yb)
           
            out = model(xb)  # forward pass
            loss = loss_fn_trn(out, yb)  # calulation loss

            trn_loss += loss.item()
            #print(loss.item())
            opt.zero_grad()  # zeroing optimizer
            loss.backward()  # backward
            opt.step()  # optimzers step
            if sched is not None:
                sched.step()  # scuedular step

        trn_loss /= mb.child.total

        # putting model in eval mode
        model.eval()
        gt = []
        pred = []
        # after epooch is done we can run a validation dataloder and see how are doing
        with torch.no_grad():
            for (xb, yb) in progress_bar(valid_dl, parent=mb):
                xb, yb = xb.to(device), yb.to(device)
                out = model(xb)
                loss = loss_fn(out, yb)
                val_loss += loss.item()

                gt.append(yb.detach())
                pred.append(out.detach())
        # calculating metric
        metric_ = metric(torch.cat(pred), torch.cat(gt))
        # saving model if necessary
        save_md(metric_, model, i)
        val_loss /= mb.child.total
        dict_res = generate_report(val_df, torch.cat(pred), f"{folder}/{exp_name}_{i}")

        pd.DataFrame(
            {
                "trn_loss": [trn_loss],
                "val_loss": [val_loss],
                "metric": [metric_],
                "roc_all": [dict_res["roc_all"]],
                "roc_0_50": [dict_res["roc_0_50"]],
                "roc_15_50": [dict_res["roc_15_50"]],
                "roc_25_50": [dict_res["roc_25_50"]],
                "roc_0_40": [dict_res["roc_0_40"]],
                "roc_0_30": [dict_res["roc_0_30"]],
            }
        ).to_csv(f"{folder}/{exp_name}_{i}.csv", index=False)
        mb.write(
            [
                i,
                f"{trn_loss:.6f}",
                f"{val_loss:.6f}",
                f"{metric_:.6f}",
                f"{dict_res['roc_all']:.6f}",
                f"{dict_res['roc_0_50']:.6f}",
                f"{dict_res['roc_15_50']:.6f}",
                f"{dict_res['roc_25_50']:.6f}",
                f"{dict_res['roc_0_40']:.6f}",
                f"{dict_res['roc_0_30']:.6f}",
            ],
            table=True,
        )
    print("Training done")
    # loading the best checkpoint

In [4]:
class DataV0():
    """
    dataset = Dataset(data_type, df)

    img, y = dataset[i]
      img (np.float32): 2 x 360 x 128
      y (np.float32): label 0 or 1
    """
    def __init__(self, df, freq_tfms=False):
        self.df = df
        self.tfms = freq_tfms
        

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """
        i (int): get ith data
        """
        r = self.df.iloc[i]
        y = np.float32(r.target)
        img = np.array(torch.load(r.id)['s_p_n'])
#        filename=r.id
#        file_id = Path(r.id).stem
#        img = np.empty((2, 360, 128), dtype=np.float32)
#        with h5py.File(filename, 'r') as f:
#            g = f[file_id]
#            try:
#                freq = f["frequency_Hz"][1:]
#            except:
#                freq = g["frequency_Hz"][:]
#            for ch, s in enumerate(['H1', 'L1']):
#                a = g[s]['SFTs'][:, :4096] * 1e22  # Fourier coefficient complex64
#                p = a.real**2 + a.imag**2  # power
#                p/= p.mean()
#                p = np.mean(p.reshape(360, 128, 32), axis=2)  # compress 4096 -> 128
#                img[ch] = p #normalize(p, clip=True)
#        img = img - img.mean()
#        img = img / img.std()
        
        if self.tfms:
            if np.random.rand() <= 0.5:  # horizontal flip
                img = np.flip(img, axis=1).copy()
            if np.random.rand() <= 0.5:  # vertical flip
                img = np.flip(img, axis=2).copy()
            if np.random.rand() <= 0.5:  # vertical shift
                img = np.roll(img, np.random.randint(low=0, high=img.shape[1]), axis=1)
        return torch.tensor(img), y

In [5]:
trn_df = pd.read_csv('../data/SPLITS/V_20/trn_df.csv')
trn_df['id'] = trn_df['id'].apply(lambda x: Path(x.replace('.h5', '.pth')))
val_df = pd.read_csv('../data/SPLITS/V_20/val_df.csv')

In [6]:
torch.load(trn_df['id'].iloc[0])

{'s_p_n': tensor([[[-1.3145, -0.9405, -1.1675,  ..., -0.7726, -0.2216,  0.9371],
          [ 0.0865, -0.9202,  0.7688,  ..., -0.7958,  0.5009,  0.8917],
          [-0.2969, -0.1678, -0.8466,  ..., -1.3051,  0.1425,  0.5670],
          ...,
          [-0.3614,  0.1445, -0.1689,  ..., -0.4277, -0.3333,  0.8875],
          [ 1.0967,  0.2559,  0.2302,  ..., -1.1830, -0.7766,  0.2369],
          [-1.0134,  0.3110,  1.2747,  ..., -0.5835, -1.7268, -0.0904]],
 
         [[ 1.1191,  0.5861,  0.4250,  ..., -0.8543,  1.4359, -0.6437],
          [-0.0519, -0.1483,  0.1070,  ..., -0.0138, -1.0058, -0.7802],
          [ 0.6024,  3.6046,  1.8719,  ..., -1.6742, -1.2916, -1.2031],
          ...,
          [ 0.1899,  0.4270,  0.4975,  ...,  1.6455, -0.5919,  2.5433],
          [ 1.8684,  1.3297, -0.1776,  ..., -1.1528,  0.1790, -0.1275],
          [ 0.5175,  0.9300, -1.1015,  ..., -0.3297, -1.1373, -0.3492]]])}

In [7]:
#image_size = 360
#patch_size = 36
#channels = 2
#num_patches = image_size // patch_size
#patch_dim = channels * patch_size ** 2

In [8]:
#x = rearrange(
#    img.unsqueeze(0), "b c (h p1) (w p2) -> b (h w) c p1 p2", p1=num_patches, p2=128
#)[0]
#

In [9]:
class ViTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        attn_layers,
        channels=2,
        num_classes=1,
        dropout=0.0,
        post_emb_norm=False,
        emb_dropout=0,
        lenth=128,
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder"
        dim = attn_layers.dim
        num_patches = image_size // patch_size
        self.patch_size = (num_patches, lenth)
        patch_dim = self.patch_size[0] * self.patch_size[1] * channels
        self.pos_embedding = nn.Parameter(torch.randn(1, patch_size + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)
        self.mlp_head = (
            nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
        )

    def forward(self, img, return_embeddings=False):
        p = self.patch_size

        x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p[0], p2=p[1])
        x = self.patch_to_embedding(x)
        n = x.shape[1]

        x = x + self.pos_embedding[:, :n]

        x = self.post_emb_norm(x)
        x = self.dropout(x)

        x = self.attn_layers(x)
        x = self.norm(x)

        if not exists(self.mlp_head) or return_embeddings:
            return x

        x = x.mean(dim=-2)
        return self.mlp_head(x)

In [10]:
# Train - val split
fold =0
trn_ds = DataV0(trn_df, True)
vld_ds = DataV0(val_df)

trn_dl = DataLoader(
    trn_ds,
    batch_size=CFG.bs,
    shuffle=True,
    num_workers=CFG.nw,
    pin_memory=True,
    drop_last=True,
)
vld_dl = DataLoader(
    vld_ds,
    batch_size=CFG.bs,
    shuffle=False,
    num_workers=CFG.nw,
    pin_memory=True,
)

custom_model = ViTransformerWrapper(
    image_size = 360,
    patch_size = 10,
    channels = 2,
    attn_layers = Encoder(
        dim = 1024,
        depth = 8,
        heads = 16
    )
)


opt = torch.optim.AdamW(custom_model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
loss_func = BCEWithLogitsLossFlat()
warmup_steps = int(len(trn_dl) * int(CFG.warmup_pct * CFG.epoch))
total_steps = int(len(trn_dl) * CFG.epoch)
sched = get_linear_schedule_with_warmup(
    opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

fit_mixup(
    epochs=CFG.epoch,
    model=custom_model,
    train_dl=trn_dl,
    valid_dl=vld_dl,
    loss_fn=loss_func,
    opt=opt,
    val_df=val_df,
    metric=custom_auc_score,
    folder=CFG.folder,
    exp_name=f"{CFG.exp_name}_{fold}",
    device="cuda:0",
    sched=sched,
)

epoch,train_loss,valid_loss,val_metric,roc_all,roc_0_50,roc_15_50,roc_25_50,roc_0_40,roc_0_30
0,0.708315,0.703645,0.468591,0.468591,0.456939,0.456939,0.458442,0.456598,0.440107
1,0.706639,0.694749,0.474931,0.474931,0.469192,0.469192,0.468743,0.474946,0.471516
2,0.697335,0.696737,0.47841,0.47841,0.457346,0.457346,0.456771,0.449046,0.484228
3,0.696033,0.695278,0.475707,0.475707,0.458671,0.458671,0.456361,0.470449,0.468646
4,0.695712,0.69548,0.486224,0.486224,0.46748,0.46748,0.469812,0.460508,0.478716
5,0.694333,0.691235,0.58216,0.58216,0.517488,0.517488,0.519064,0.521713,0.498588
6,0.669491,0.693767,0.628319,0.628319,0.520741,0.520741,0.526945,0.498919,0.485267
7,0.579441,0.615685,0.681927,0.681927,0.547768,0.547768,0.550106,0.509319,0.509251
8,0.537481,0.596045,0.738616,0.738616,0.591657,0.591657,0.6019,0.543674,0.510827
9,0.515147,0.56398,0.752278,0.752278,0.612063,0.612063,0.616671,0.569594,0.547103


Better model found at epoch 0 with value: 0.46859111111111107.
Training done
