# Play with 3dshapes

Load a small subset of the 3D Shapes dataset and pass it through a pretrained ChAdaViT encoder.

## Imports

In [1]:
import os
# Project root. Defaults to the original author's checkout; set the
# DISDIFF_PROJECT_PATH environment variable to run on another machine.
PROJECT_PATH = os.environ.get("DISDIFF_PROJECT_PATH", "/home/alexandre/disdiff_adaptaters")
# Work from the project root, presumably so the project-relative imports below
# (`disdiff_adaptaters`, `src`) resolve — TODO confirm.
os.chdir(PROJECT_PATH)

import lightning as L
from lightning import LightningDataModule, LightningModule

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import h5py

from disdiff_adaptaters.data_module.shapes3d import Shapes3DDataModule
from disdiff_adaptaters.utils.utils import load_h5
from disdiff_adaptaters.utils.const import Shapes3D

from src.backbones.vit.chada_vit import ChAdaViT

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [2]:
# Peek at a 100-sample subset of the dataset.
images, labels = load_h5(Shapes3D.Path.H5)
# .copy(): if load_h5 returns NumPy arrays, a bare slice is a view that keeps the
# full arrays alive; copying lets the full dataset be freed once rebound.
images = images[:100].copy()
labels = labels[:100].copy()
print(images.shape, labels.shape)

(100, 64, 64, 3) (100, 6)


In [3]:
# Choose the data source: the full Lightning data module, or (local=True) a small
# in-memory subset wrapped in a plain DataLoader for quick experiments.
local = True
if not local:
    data_module = Shapes3DDataModule()
    data_module.prepare_data()
    data_module.setup(stage="fit")
    train_loader = data_module.train_dataloader()
else:
    images, labels = load_h5(Shapes3D.Path.H5)
    # .copy(): if load_h5 returns NumPy arrays, a bare slice is a view that keeps
    # the full arrays alive for as long as `images`/`labels` exist.
    images = images[:100].copy()
    labels = labels[:100].copy()
    train_loader = DataLoader(
        TensorDataset(
            torch.tensor(images).permute(0, 3, 1, 2),  # (N, H, W, C) -> (N, C, H, W)
            torch.tensor(labels),
        ),
        batch_size=8,
        shuffle=True,
        # Seeded generator so the shuffled batch order is reproducible across runs.
        generator=torch.Generator().manual_seed(0),
    )

## Load encoder

In [4]:
# Params — chosen to match the checkpoint below (its name encodes embed_dim_192_patch_16).
PATCH_SIZE = 16
EMBED_DIM = 192
# NOTE(review): presumably False returns one pooled token per image rather than all tokens — confirm in ChAdaViT.
RETURN_ALL_TOKENS = False
# NOTE(review): presumably the maximum number of channels per image the model accepts — confirm in ChAdaViT.
MAX_NUMBER_CHANNELS = 10

# Derived from PROJECT_PATH instead of repeating the absolute project root.
CKPT_PATH = os.path.join(
    PROJECT_PATH,
    "disdiff_adaptaters/arch/Dino-IDRCell100k-vit_c-embed_dim_192_patch_16-310682-ep=399.ckpt",
)

In [5]:
# Instantiate the ChAdaViT backbone from the config constants; keyword order is
# irrelevant, so the embedding/patch geometry is listed first.
model = ChAdaViT(
    embed_dim=EMBED_DIM,
    patch_size=PATCH_SIZE,
    max_number_channels=MAX_NUMBER_CHANNELS,
    return_all_tokens=RETURN_ALL_TOKENS,
)

In [6]:
def remap_backbone_state(raw_state: dict) -> dict:
    """Extract backbone weights from a Lightning ``state_dict`` and key them for ChAdaViT.

    Every ``encoder`` in a key is first renamed to ``backbone``. Each key that then
    contains ``backbone`` is kept with ``backbone.`` removed, and all other keys
    are dropped. Both renames are applied to the same key, so ``encoder.*`` weights
    end up under bare names as well.

    Args:
        raw_state: mapping of checkpoint parameter names to tensors.

    Returns:
        A new dict; ``raw_state`` is left unmodified.
    """
    backbone_state = {}
    for raw_key, value in raw_state.items():
        key = raw_key.replace("encoder", "backbone")
        if "backbone" in key:
            backbone_state[key.replace("backbone.", "")] = value
    return backbone_state


# NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
state = remap_backbone_state(
    torch.load(CKPT_PATH, map_location="cpu", weights_only=False)["state_dict"]
)
load_result = model.load_state_dict(state, strict=False)
# strict=False silently ignores key mismatches, so surface them explicitly.
print(
    f"missing keys: {len(load_result.missing_keys)}, "
    f"unexpected keys: {len(load_result.unexpected_keys)}"
)
model.to("cpu")
model.eval()

ChAdaViT(
  (token_learner): TokenLearner(
    (proj): Conv2d(1, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (linear1): Linear(in_features=192, out_features=2048, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (linear2): Linear(in_features=2048, out_features=192, bias=True)
      (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.0, inplace=False)
      (dropout2): Dropout(p=0.0, inplace=False)
    )
  )
  (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (head): Identity()
)

In [10]:
# Encode every batch for inference only (no autograd graph).
# The TokenLearner is a Conv2d with a single input channel (see the model repr),
# so each (B, C, H, W) batch is flattened into B*C single-channel maps, with the C
# maps of one image kept consecutive, and C is declared per image.
# NOTE(review): confirm against ChAdaViT.forward that this is the expected
# channel layout and list_num_channels format.
# Assumes pixel values are in [0, 255], as the original /255.0 scaling implies.
batch_embeddings = []
with torch.no_grad():
    for batch_images, _ in train_loader:
        pixels = batch_images.to(torch.float32) / 255.0
        n_images, n_channels, height, width = pixels.shape
        channel_maps = pixels.reshape(n_images * n_channels, 1, height, width)
        batch_embeddings.append(
            model(channel_maps, index=0, list_num_channels=n_images * [n_channels])
        )
len(batch_embeddings)