In [1]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Load the three KSTAR disruption CSV tables used throughout this notebook.
# The shot list is read with an explicit EUC-KR encoding; the other two files
# rely on pandas' default encoding.
kstar_shot_list = pd.read_csv(
    './dataset/KSTAR_Disruption_Shot_List_extend.csv',
    encoding="euc-kr",
)

# Presumably the time-series signals for the multi-modal setup (inferred from
# the file name only -- TODO confirm the schema).
ts_data = pd.read_csv(
    "./dataset/KSTAR_Disruption_ts_data_for_multi.csv",
)

# Presumably per-sample metadata for the multi-modal dataset (inferred from
# the file name only -- TODO confirm the schema).
mult_info = pd.read_csv(
    "./dataset/KSTAR_Disruption_multi_data.csv",
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.models.ViViT import ViViT

# Stand-alone video branch: a ViViT classifier over 21-frame, 3-channel clips.
# Per the summary printed by this cell, 128x128 frames cut into 16x16 patches
# yield 64 patches per frame, each flattened to 3 * 16 * 16 = 768 values
# before the linear patch embedding projects them to `dim` = 64.
video_model = ViViT(
    # --- input geometry ---
    image_size=128,
    patch_size=16,
    in_channels=3,
    n_frames=21,
    # --- transformer sizing ---
    dim=64,
    depth=4,
    n_heads=8,
    d_head=64,
    scale_dim=4,
    pool="cls",
    # --- regularisation ---
    dropout=0.25,
    embedd_dropout=0.25,
    # --- two output classes ---
    n_classes=2,
)

# Emit the per-layer input-shape / parameter-count table.
video_model.summary()

----------------------------------------------------------------------------
      Layer (type)              Input Shape         Param #     Tr. Param #
       Rearrange-1     [1, 21, 3, 128, 128]               0               0
          Linear-2         [1, 21, 64, 768]          49,216          49,216
         Dropout-3          [1, 21, 65, 64]               0               0
     Transformer-4             [21, 65, 64]         658,048         658,048
     Transformer-5              [1, 22, 64]         658,048         658,048
       LayerNorm-6                  [1, 64]             128             128
          Linear-7                  [1, 64]             130             130
Total params: 1,365,570
Trainable params: 1,365,570
Non-trainable params: 0
----------------------------------------------------------------------------



ViViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange(), 0 params
    (1): Linear(in_features=768, out_features=64, bias=True), 49,216 params
  ), 49,2

In [3]:
# Inspect the first entry of `video_model.mlp`; the recorded output shows it
# is a LayerNorm over 64 features (eps=1e-05, elementwise_affine=True).
video_model.mlp[0]

LayerNorm((64,), eps=1e-05, elementwise_affine=True)

In [4]:
from src.models.ConvLSTM import ConvLSTM

# Stand-alone time-series branch. Judging by the summary this cell prints, the
# first Conv1d receives input of shape [1, 9, 21], i.e. `col_dim` channels by
# `seq_len` time steps; all other ConvLSTM settings are left at their defaults.
ts_model = ConvLSTM(seq_len=21, col_dim=9)

# Emit the per-layer input-shape / parameter-count table.
ts_model.summary()

-------------------------------------------------------------------------------------------
      Layer (type)                             Input Shape         Param #     Tr. Param #
          Conv1d-1                              [1, 9, 21]             896             896
     BatchNorm1d-2                             [1, 32, 21]              64              64
            ReLU-3                             [1, 32, 21]               0               0
          Conv1d-4                             [1, 32, 21]           3,104           3,104
     BatchNorm1d-5                             [1, 32, 21]              64              64
            ReLU-6                             [1, 32, 21]               0               0
            LSTM-7     [32, 1, 21], [2, 1, 64], [2, 1, 64]          44,544          44,544
          Linear-8                            [1, 32, 128]           8,256           8,256
          Linear-9                             [1, 32, 64]           4,160           4,16

In [5]:
from src.models.mult_modal import MultiModalNetwork

# Keyword bundle for the video (ViViT-style) branch of the multi-modal network.
# NOTE(review): patch_size is 32 here but 16 in the stand-alone ViViT cell
# earlier in this notebook -- confirm the difference is intentional.
args_video = dict(
    image_size=128,
    patch_size=32,
    n_frames=21,
    dim=64,
    depth=4,
    n_heads=8,
    pool='cls',
    in_channels=3,
    d_head=64,
    dropout=0.25,
    embedd_dropout=0.25,
    scale_dim=4,
)

# Keyword bundle for the 0D time-series branch. seq_len / col_dim repeat the
# values passed to the stand-alone ConvLSTM earlier in this notebook; the
# remaining entries presumably spell out that model's convolution and LSTM
# sizes explicitly (TODO confirm against ConvLSTM's defaults).
args_0D = dict(
    seq_len=21,
    col_dim=9,
    conv_dim=32,
    conv_kernel=3,
    conv_stride=1,
    conv_padding=1,
    lstm_dim=64,
)

# Keyword bundle for the fusion stage; how these kernel / stride values are
# consumed is defined in MultiModalNetwork and is not visible in this notebook.
args_fusion = dict(
    kernel_size=4,
    stride=2,
    maxpool_kernel=3,
    maxpool_stride=2,
)

# Use the GPU when PyTorch reports CUDA as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the recorded output of this cell is
#   TypeError: __init__() takes 4 positional arguments but 5 were given
# i.e. some constructor reached from this cell accepts one positional argument
# fewer (counting `self`) than it was handed. The traceback is truncated, but
# the most likely culprit is the MultiModalNetwork(...) call below, which
# passes four positional arguments. Check the current signature in
# src/models/mult_modal.py and pass the arguments by keyword -- TODO confirm.
model= MultiModalNetwork(2, args_video, args_0D, args_fusion)
# Move the model's parameters onto the selected device.
model.to(device)
# NOTE(review): the meaning of the four positional booleans is not visible in
# this notebook; passing them by keyword would make the call self-explanatory.
model.summary(device, True, False, True, False)

TypeError: __init__() takes 4 positional arguments but 5 were given