In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
import pandas as pd

import pickle, os, sys

# imagebind
from imagebind.data import load_and_transform_vision_data 
import encoder.custom_ibvis_encoder as cibv
import encoder.custom_ib_model as cibm
from constantinople import Constantinople
from info_nce import InfoNCE
from tqdm import tqdm
from custom_logger import get_logger



In [4]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [5]:
def get_model():
    if os.path.exists('constantinople_512.pth'):
        model = torch.load('constantinople_512.pth')
    model = Constantinople()
    model.load_state_dict(torch.load('.checkpoints/constantinople_pretrained_512.pth'), strict = False)
    return model

In [6]:
model = get_model()

In [7]:
def freeze_model(model, exception = []):
    for name, child in model.named_children():
        if name in exception:
            for n, param in child.named_parameters():
                param.requires_grad = True
                print(name, n, param.requires_grad)
        else:
            for param in child.parameters():
                param.requires_grad = False
            freeze_model(child, exception)            

In [8]:
freeze_model(model, ['modality_postprocessors', 'modality_heads', 'polling'])

modality_heads 0.weight True
modality_heads 0.bias True
modality_heads 2.weight True
polling weight True
polling bias True


In [9]:
model.train()
model.to(device)

Constantinople(
  (image_encoder): CustomIbvisEncoder(
    (modality_preprocessors): RGBDTPreprocessor(
      (cls_token): tensor((1, 1, 1280), requires_grad=False)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Sequential(
          (0): PadIm2Video()
          (1): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
        )
      )
      (pos_embedding_helper): SpatioTemporalPosEmbeddingHelper(
        (pos_embed): tensor((1, 257, 1280), requires_grad=False)
        
      )
    )
    (modality_trunks): SimpleTransformer(
      (pre_transformer_layer): Sequential(
        (0): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (1): EinOpsRearrange()
      )
      (blocks): Sequential(
        (0): BlockWithMasking(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1280, out_features=1280, bias=True)
          )
          (drop_path): Identity()
          (norm_1): LayerNorm((1280,

In [10]:
def get_dataset(dataset_path = 'data/constantinople_dataset.pkl'):
    if os.path.exists(dataset_path):
        with open(dataset_path, 'rb') as f:
            dataset = pickle.load(f)
        return dataset
    
    with open('data/constantinople_data.pkl', 'rb') as f:
        data = pickle.load(f)
    
    portion = torch.Tensor(list(data['portion'].values())).to(device)
    touch = torch.Tensor(list(data['touch'].values())).to(device)
    path = list(data['image'].values())
    
    image = load_and_transform_vision_data(path, device)
    dataset = TensorDataset(image, portion, touch)
    print(portion.dtype, touch.dtype, image.dtype)
    with open(dataset_path, 'wb') as f:
        pickle.dump(dataset, f)
    
    return dataset

In [11]:
dataset = get_dataset()

In [12]:
train_size = 80
val_size = 7
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
print(train_size, val_size)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

80 7


In [13]:
criterion = InfoNCE()
optimizer = optim.Adam(model.parameters(), lr=0.005)

In [14]:
# 모델 학습 및 검증
num_epochs = 5
best_val_loss = float('inf')
logger = get_logger()

In [15]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs=25):
    model.train()
    train_loss = 0
    for epoch in range(num_epochs):
        running_loss = 0.0
        batch = 1
        for image_em, portion, labels in train_loader:
            logger.info(f'train batch {batch} start')
            image_em, portion ,labels = image_em.to(device), portion.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(image_em, portion)
            logger.info(f'train batch {batch} output : {outputs.shape} ')

            loss = criterion(outputs, labels)
            logger.info(f'train batch {batch} loss : {loss:.4f} ')

            loss.backward()
            optimizer.step()
            running_loss += loss.item() * image_em.size(0)
            logger.info(f'train batch {batch} end, running_loss: {running_loss:.4f}')
            batch += 1

        epoch_loss = running_loss / len(train_loader.dataset)
        train_loss = epoch_loss
        logger.info(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.4f}')
    return train_loss

# 검증 함수
def evaluate_model(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        batch = 1
        for image_em, portion, labels in val_loader:
            logger.info(f'valid batch {batch} start')

            image_em, portion ,labels = image_em.to(device), portion.to(device), labels.to(device)
            outputs = model(image_em, portion)
            logger.info(f'train batch {batch} output : {outputs.shape} ')

            loss = criterion(outputs, labels)
            logger.info(f'train batch {batch} loss : {loss:.4f} ')
            
            val_loss += loss.item() * image_em.size(0)
            logger.info(f'valid batch {batch} end, running_loss: {val_loss:.4f}')
            batch +=1
    val_loss /= len(val_loader.dataset)
    logger.info(f'Validation Loss: {val_loss:.4f}')
    return val_loss


In [18]:
train_loss = []
val_loss = []
if os.path.exists('constantinople_best.pth'):
    model.load_state_dict(torch.load('constantinople_best.pth'))
    

for epoch in range(1):
    # train, valid loss 계산
    train_loss.append(train_model(model, train_loader, criterion, optimizer, device, num_epochs=5))
    val_loss.append(evaluate_model(model, val_loader, criterion, device))
    # valid_loss 최솟값 저장
    logger.info(f'TOTAL RUN {epoch}/{num_epochs - 1}, Train_Loss: {train_loss[-1]:.4f},  VAoid_Loss: {val_loss[-1]:.4f}')
    if val_loss[-1] < best_val_loss:
        best_val_loss = val_loss[-1]
        torch.save(model.state_dict(), f'constantinople_best')
        logger.info(f'UPDATE best_val_loss :{best_val_loss:.4f}')

        
    

In [17]:
logger.info(f'train :{train_loss}')
logger.info(f'valid :{val_loss}')
logger.info(f'train end')