In [1]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

kstar_shot_list = pd.read_csv('./dataset/KSTAR_Disruption_Shot_List_extend.csv', encoding = "euc-kr")
ts_data = pd.read_csv("./dataset/KSTAR_Disruption_ts_data_for_multi.csv")
mult_info = pd.read_csv("./dataset/KSTAR_Disruption_multi_data.csv")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ts_data.head()

Unnamed: 0,time,\q95,\ipmhd,\kappa,\tritop,\tribot,\betap,\betan,\li,\WTOT_DLM03
0,3.135,4.451424,-517687.502496,1.592205,0.290356,0.669179,0.60462,0.798851,1.393783,113.760088
1,3.154048,4.451229,-517758.959331,1.592658,0.290501,0.66938,0.604886,0.799086,1.394316,113.898366
2,3.173095,4.451524,-517773.000585,1.593156,0.290629,0.669634,0.606237,0.800684,1.393948,114.889598
3,3.192143,4.451977,-517768.464867,1.593668,0.290753,0.669905,0.607941,0.802722,1.393288,116.156805
4,3.21119,4.449676,-517497.61765,1.593433,0.288556,0.670172,0.607891,0.802569,1.395915,116.355835


In [3]:
mult_info.head()

Unnamed: 0,frame_start,frame_end,is_disrupt,shot,task,path,t_start,t_end,t_start_index,t_end_index
0,656,740,False,21273,train,./dataset/dur84_dis4/train/normal/21273_656_740,3.115953,3.515952,0,20
1,740,824,False,21273,valid,./dataset/dur84_dis4/valid/normal/21273_740_824,3.515952,3.915952,21,41
2,824,908,False,21273,train,./dataset/dur84_dis4/train/normal/21273_824_908,3.915952,4.315952,42,62
3,908,992,False,21273,train,./dataset/dur84_dis4/train/normal/21273_908_992,4.315952,4.715952,63,83
4,992,1076,False,21273,train,./dataset/dur84_dis4/train/normal/21273_992_1076,4.715952,5.115952,84,104


In [4]:
kstar_shot_list.head()

Unnamed: 0,shot,year,tftsrt,tipminf,tTQend,dt,frame_cutoff,frame_tTQend,frame_tipminf
0,21273,2018,2.996,5.535,5.514,0.021,1165,1160,1164
1,21274,2018,2.996,10.056,10.038,0.018,2104,2100,2103
2,21310,2018,1.5,5.368,5.342,0.026,1131,1125,1130
3,21315,2018,1.5,7.804,7.782,0.022,1636,1631,1635
4,21317,2018,1.5,9.46,9.438,0.022,1980,1975,1979


In [5]:
from torch.utils.data import Dataset
from typing import Optional, Literal, List, Union
from tqdm.auto import tqdm
from src.CustomDataset import DEFAULT_TS_COLS
import os, cv2

class MultiModalDataset(Dataset):
    def __init__(
        self, 
        task : Literal["train", "valid", "test"] = "train", 
        ts_data : Optional[pd.DataFrame] = None,
        ts_cols : Optional[List] = None,
        mult_info : Optional[pd.DataFrame] = None,
        dt : Optional[float] = 1.0 / 210 * 4,
        distance : Optional[int] = 0,
        n_fps : Optional[int] = 4,
        resize_height : Optional[int] = 256,
        resize_width : Optional[int] = 256,
        crop_size : Optional[int] = 128,
        seq_len : int = 21,
        n_classes : int = 2,
        ):
        self.task = task # task : train / valid / test 
        
        # resize each frame from video
        self.resize_height = resize_height
        self.resize_width = resize_width
        
        # crop
        self.crop_size = crop_size
        
        # video sequence length
        # warning : 0D data and video data should have equal sequence length
        self.seq_len = seq_len
        
        # use for 0D data prediction
        self.distance = distance # prediction time
        self.dt = dt # time difference of 0D data
        self.n_fps = n_fps

        # video_file_path : video file path : {database}/{shot_num}_{frame_start}_{frame_end}.avi
        # indices : index for tabular data, shot == shot_num, index <- df[df.frame_idx == frame_start].index
        self.n_classes = n_classes

        self.ts_data = ts_data
        self.mult_info = mult_info
        self.ts_cols = ts_cols
        
        # select columns for 0D data prediction
        if ts_cols is None:
            self.ts_cols = DEFAULT_TS_COLS
            
        self.video_file_path = mult_info[mult_info.task == task]["path"].values.tolist()
        self.labels = [0 if label is True else 1 for label in mult_info[mult_info.task == task].is_disrupt]
        self.indices = mult_info[mult_info.task == task]["t_start_index"].astype(int).values.tolist()

    def load_frames(self, file_dir : str):
        frames = sorted([os.path.join(file_dir, img) for img in os.listdir(file_dir)])
        frame_count = self.seq_len
        buffer = np.empty((frame_count, self.resize_height, self.resize_width, 3), np.dtype('float32'))
        
        for i, frame_name in enumerate(frames[::-1][::self.n_fps][::-1]):
            frame = np.array(cv2.imread(frame_name)).astype(np.float32)
            buffer[i] = frame
    
        return buffer
    
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx : int):
        x_video = self.get_video_data(idx)
        x_tabular = self.get_tabular_data(idx)
        label = torch.from_numpy(np.array(self.labels[idx]))
        return x_video, x_tabular, label

    def get_video_data(self, index : int):
        buffer = self.load_frames(self.video_file_path[index])
        if buffer.shape[0] < self.seq_len:
            buffer = self.refill_temporal_slide(buffer)
        buffer = self.crop(buffer, self.seq_len, self.crop_size)
        buffer = self.normalize(buffer)
        buffer = self.to_tensor(buffer)
        return torch.from_numpy(buffer)
    
    def get_tabular_data(self, index : int):
        ts_idx = self.indices[index]
        data = self.ts_data[self.ts_cols].loc[ts_idx:ts_idx+self.seq_len-1].values
        return torch.from_numpy(data)

    def refill_temporal_slide(self, buffer:np.ndarray):
        for _ in range(self.seq_len - buffer.shape[0]):
            frame_new = buffer[-1].reshape(1, self.resize_height, self.resize_width, 3)
            buffer = np.concatenate((buffer, frame_new))
        return buffer

    def normalize(self, buffer):
        for i, frame in enumerate(buffer):
            frame -= np.array([[[90.0, 98.0, 102.0]]])
            buffer[i] = frame
        return buffer

    def to_tensor(self, buffer:Union[np.ndarray, torch.Tensor]):
        return buffer.transpose((3, 0, 1, 2))

    def crop(self, buffer : Union[np.ndarray, torch.Tensor], clip_len : int, crop_size : int, is_random : bool = False):
        if buffer.shape[0] < clip_len :
            time_index = np.random.randint(abs(buffer.shape[0] - clip_len))
        elif buffer.shape[0] == clip_len :
            time_index = 0
        else :
            time_index = np.random.randint(buffer.shape[0] - clip_len)

        if not is_random:
            original_height = self.resize_height
            original_width = self.resize_width
            mid_x, mid_y = original_height // 2, original_width // 2
            offset_x, offset_y = crop_size // 2, crop_size // 2
            buffer = buffer[time_index : time_index + clip_len, mid_x - offset_x:mid_x+offset_x, mid_y - offset_y: mid_y+ offset_y, :]
        else:
            height_index = np.random.randint(buffer.shape[1] - crop_size)
            width_index = np.random.randint(buffer.shape[2] - crop_size)

            buffer = buffer[time_index:time_index + clip_len,
                    height_index:height_index + crop_size,
                    width_index:width_index + crop_size, :]
        return buffer

    # function for imbalanced dataset
    # used for LDAM loss and re-weighting
    def get_num_per_cls(self):
        classes = np.unique(self.labels)
        self.num_per_cls_dict = dict()

        for cls in classes:
            num = np.sum(np.where(self.labels == cls, 1, 0))
            self.num_per_cls_dict[cls] = num
         
    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.n_classes):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list
    

train_data = MultiModalDataset('train', ts_data, DEFAULT_TS_COLS, mult_info, dt = 1 / 210 * 4, distance = 4, seq_len = 21)
valid_data = MultiModalDataset('valid', ts_data, DEFAULT_TS_COLS, mult_info, dt = 1 / 210 * 4, distance = 4, seq_len = 21)
test_data = MultiModalDataset('test', ts_data, DEFAULT_TS_COLS, mult_info, dt = 1 / 210 * 4, distance = 4, seq_len = 21)

from torch.utils.data import DataLoader
from src.utils.sampler import ImbalancedDatasetSampler

batch_size = 32
sampler = ImbalancedDatasetSampler(train_data)
train_loader = DataLoader(train_data, batch_size = batch_size, num_workers = 8, sampler = sampler)
valid_loader = DataLoader(valid_data, batch_size = batch_size, num_workers = 8, shuffle = True)
test_loader = DataLoader(test_data, batch_size = batch_size, num_workers = 8, shuffle = True)

sample_video, sample_0D, sample_target = next(iter(train_loader))
print("sample_video : ", sample_video.size())
print("sample_0D : ", sample_0D.size())
print("sample_target : ", sample_target.size())

sample_video :  torch.Size([32, 3, 21, 128, 128])
sample_0D :  torch.Size([32, 21, 9])
sample_target :  torch.Size([32])


In [11]:
model(sample_video.to(device), sample_0D.to(device))

RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same

In [6]:
from typing import Optional, List, Literal, Union
from src.loss import LDAMLoss, FocalLoss
import torch
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score

def train_per_epoch(
    train_loader : DataLoader, 
    model : torch.nn.Module,
    optimizer : torch.optim.Optimizer,
    scheduler : Optional[torch.optim.lr_scheduler._LRScheduler],
    loss_fn : torch.nn.Module,
    device : str = "cpu",
    max_norm_grad : Optional[float] = None
    ):

    model.train()
    model.to(device)

    train_loss = 0
    train_acc = 0

    total_pred = np.array([])
    total_label = np.array([])
    total_size = 0

    for batch_idx, (x_video, x_0D, target) in enumerate(train_loader):
        optimizer.zero_grad()
        x_video = x_video.to(device)
        x_0D = x_0D.to(device)
        target = target.to(device)
        
        print("x_video : ", x_video.size())
        print("x_0D : ", x_0D.size())
        print("target : ", target.size())
    
        output = model(x_video, x_0D)
        loss = loss_fn(output, target)

        loss.backward()
        
        print("output : ", output.size())

        # use gradient clipping
        if max_norm_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm_grad)

        optimizer.step()

        train_loss += loss.item()

        pred = torch.nn.functional.softmax(output, dim = 1).max(1, keepdim = True)[1]
        train_acc += pred.eq(target.view_as(pred)).sum().item()
        total_size += x_video.size(0) 
        
        total_pred = np.concatenate((total_pred, pred.cpu().numpy().reshape(-1,)))
        total_label = np.concatenate((total_label, target.cpu().numpy().reshape(-1,)))
        
    if scheduler:
        scheduler.step()

    train_loss /= total_size
    train_acc /= total_size

    train_f1 = f1_score(total_label, total_pred, average = "macro")

    return train_loss, train_acc, train_f1

def valid_per_epoch(
    valid_loader : DataLoader, 
    model : torch.nn.Module,
    optimizer : torch.optim.Optimizer,
    loss_fn : torch.nn.Module,
    device : str = "cpu",
    ):

    model.eval()
    model.to(device)
    valid_loss = 0
    valid_acc = 0

    total_pred = np.array([])
    total_label = np.array([])
    total_size = 0

    for batch_idx, (x_video, x_0D, target) in enumerate(valid_loader):
        with torch.no_grad():
            optimizer.zero_grad()
            x_video = x_video.to(device)
            x_0D = x_0D.to(device)
            target = target.to(device)
        
            output = model(x_video, x_0D)

            loss = loss_fn(output, target)
    
            valid_loss += loss.item()
            pred = torch.nn.functional.softmax(output, dim = 1).max(1, keepdim = True)[1]
            valid_acc += pred.eq(target.view_as(pred)).sum().item()
            total_size += x_video.size(0) 

            total_pred = np.concatenate((total_pred, pred.cpu().numpy().reshape(-1,)))
            total_label = np.concatenate((total_label, target.cpu().numpy().reshape(-1,)))

    valid_loss /= total_size
    valid_acc /= total_size

    valid_f1 = f1_score(total_label, total_pred, average = "macro")

    return valid_loss, valid_acc, valid_f1

def train(
    train_loader : DataLoader, 
    valid_loader : DataLoader,
    model : torch.nn.Module,
    optimizer : torch.optim.Optimizer,
    scheduler : Optional[torch.optim.lr_scheduler._LRScheduler],
    loss_fn : Union[torch.nn.CrossEntropyLoss, LDAMLoss, FocalLoss],
    device : str = "cpu",
    num_epoch : int = 64,
    verbose : Optional[int] = 8,
    save_best_dir : str = "./weights/best.pt",
    save_last_dir : str = "./weights/last.pt",
    max_norm_grad : Optional[float] = None,
    criteria : Literal["f1_score", "acc", "loss"] = "f1_score",
    ):

    train_loss_list = []
    valid_loss_list = []
    
    train_acc_list = []
    valid_acc_list = []

    train_f1_list = []
    valid_f1_list = []

    best_acc = 0
    best_epoch = 0
    best_f1 = 0
    best_loss = torch.inf

    for epoch in tqdm(range(num_epoch), desc = "training process"):

        train_loss, train_acc, train_f1 = train_per_epoch(
            train_loader, 
            model,
            optimizer,
            scheduler,
            loss_fn,
            device,
            max_norm_grad
        )

        valid_loss, valid_acc, valid_f1 = valid_per_epoch(
            valid_loader, 
            model,
            optimizer,
            loss_fn,
            device 
        )

        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)

        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

        train_f1_list.append(train_f1)
        valid_f1_list.append(valid_f1)

        if verbose:
            if epoch % verbose == 0:
                print("epoch : {}, train loss : {:.3f}, valid loss : {:.3f}, train acc : {:.3f}, valid acc : {:.3f}, train f1 : {:.3f}, valid f1 : {:.3f}".format(
                    epoch+1, train_loss, valid_loss, train_acc, valid_acc, train_f1, valid_f1
                ))

        # save the best parameters
        
        if criteria == "acc" and best_acc < valid_acc:
            best_acc = valid_acc
            best_f1 = valid_f1
            best_loss = valid_loss
            best_epoch  = epoch
            torch.save(model.state_dict(), save_best_dir)
        elif criteria == "f1_score" and best_f1 < valid_f1:
            best_acc = valid_acc
            best_f1 = valid_f1
            best_loss = valid_loss
            best_epoch  = epoch
            torch.save(model.state_dict(), save_best_dir)
        elif criteria == "loss" and best_loss > valid_loss:
            best_acc = valid_acc
            best_f1 = valid_f1
            best_loss = valid_loss
            best_epoch  = epoch
            torch.save(model.state_dict(), save_best_dir)

        # save the last parameters
        torch.save(model.state_dict(), save_last_dir)

    # print("\n============ Report ==============\n")
    print("training process finished, best loss : {:.3f} and best acc : {:.3f}, best f1 : {:.3f}, best epoch : {}".format(
        best_loss, best_acc, best_f1, best_epoch
    ))

    return  train_loss_list, train_acc_list, train_f1_list,  valid_loss_list,  valid_acc_list, valid_f1_list

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lr = 1e-3

from src.loss import FocalLoss
from src.models.mult_modal import MultiModalModel

train_data.get_num_per_cls()
cls_num_list = train_data.get_cls_num_list()
per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights)
per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
loss_fn = FocalLoss(per_cls_weights, gamma = 2)

args_video = {
    "image_size" : 128, 
    "patch_size" : 32, 
    "n_frames" : 21, 
    "dim": 64, 
    "depth" : 4, 
    "n_heads" : 8, 
    "pool" : 'cls', 
    "in_channels" : 3, 
    "d_head" : 64, 
    "dropout" : 0.25,
    "embedd_dropout":  0.25, 
    "scale_dim" : 4
}

args_0D = {
    "seq_len" : 21, 
    "col_dim" : 9, 
    "conv_dim" : 32, 
    "conv_kernel" : 3,
    "conv_stride" : 1, 
    "conv_padding" : 1,
    "lstm_dim" : 64, 
}
    
model = MultiModalModel(
    2, args_video, args_0D
)
model.summary('cpu')
model.to(device)

num_epoch = 64
verbose = 4
save_best_dir = "./weights/multi_modal_best.pt"
save_last_dir = "./weights/multi_modal_last.pt"
max_norm_grad = 1.0
criteria = "f1_score"
optimizer = torch.optim.AdamW(model.parameters(), lr = lr)

train_loss, train_acc, train_f1, valid_loss, valid_acc, valid_f1 = train(
    train_loader,
    valid_loader,
    model,
    optimizer,
    None,
    loss_fn,
    device,
    num_epoch,
    verbose,
    save_best_dir,
    save_last_dir,
    max_norm_grad,
    criteria
)

------------------------------------------------------------------------------
        Layer (type)              Input Shape         Param #     Tr. Param #
      ViViTEncoder-1     [1, 3, 21, 128, 128]       1,535,744       1,535,744
   ConvLSTMEncoder-2               [1, 21, 9]          61,088          61,088
            Linear-3                 [1, 192]          18,528          18,528
       BatchNorm1d-4                  [1, 96]             192             192
              ReLU-5                  [1, 96]               0               0
            Linear-6                  [1, 96]             194             194
Total params: 1,615,746
Trainable params: 1,615,746
Non-trainable params: 0
------------------------------------------------------------------------------


training process:   0%|          | 0/64 [00:00<?, ?it/s]

x_video :  torch.Size([32, 3, 21, 128, 128])
x_0D :  torch.Size([32, 21, 9])
target :  torch.Size([32])


training process:   0%|          | 0/64 [00:03<?, ?it/s]


RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same