# Play with pipeline

### Import

In [1]:
import os
PROJECT_PATH=os.getenv("IBENS_PROJECT_PATH") #Si bug: source ~/.bashrc
os.chdir(PROJECT_PATH)

import lightning as L
from lightning import LightningDataModule, LightningModule

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import h5py

from disdiff_adaptaters.data_module.shapes3d import Shapes3DDataModule
from disdiff_adaptaters.data_module.bloodmnist import BloodMNISTDataModule

from disdiff_adaptaters.utils.utils import load_h5, collate_images
from disdiff_adaptaters.utils.const import Shapes3D, ChAda, BloodMNIST

from chada.backbones.vit.chada_vit import ChAdaViT

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [2]:
device = "cuda" if torch.cuda.is_available else "cpu"
print(f"device is at {device}")

device is at cuda


### 3DShapes

In [None]:
if device == "cuda" :
    data_module = Shapes3DDataModule()
    data_module.prepare_data()
    data_module.setup(stage='fit')
    train_loader = data_module.train_dataloader()
else :
    images, labels = load_h5(Shapes3D.Path.H5)
    images = images[:100]
    labels = labels[:100]
    train_loader = DataLoader(TensorDataset(torch.tensor(images).permute(0,3,1,2), torch.tensor(labels)), batch_size=8, shuffle=True)

start split from h5 file


### BloodMNIST

In [3]:
data = BloodMNISTDataModule()
data.prepare_data()
data.setup("fit")

In [4]:
loader = data.train_dataloader()
for batch in loader : 
    print(batch)

[tensor([[[[0.8588, 0.8784, 0.9412,  ..., 0.7882, 0.8039, 0.8157],
          [0.8314, 0.8667, 0.9686,  ..., 0.8667, 0.8745, 0.8784],
          [0.8275, 0.8627, 0.9569,  ..., 0.8941, 0.8863, 0.8824],
          ...,
          [0.9255, 0.9529, 0.9569,  ..., 0.7333, 0.8510, 0.9804],
          [0.9843, 0.9882, 0.9725,  ..., 0.7373, 0.8431, 0.9608],
          [0.9725, 0.9804, 0.9686,  ..., 0.7294, 0.8196, 0.9137]],

         [[0.7137, 0.7373, 0.8039,  ..., 0.6353, 0.6471, 0.6627],
          [0.6980, 0.7373, 0.8431,  ..., 0.7137, 0.7216, 0.7294],
          [0.6941, 0.7294, 0.8353,  ..., 0.7490, 0.7412, 0.7373],
          ...,
          [0.8196, 0.8471, 0.8510,  ..., 0.6157, 0.7333, 0.8627],
          [0.8902, 0.8941, 0.8706,  ..., 0.6039, 0.7098, 0.8314],
          [0.8902, 0.8980, 0.8745,  ..., 0.5961, 0.6863, 0.7804]],

         [[0.6784, 0.6824, 0.7255,  ..., 0.7176, 0.7412, 0.7373],
          [0.6588, 0.6784, 0.7608,  ..., 0.7843, 0.7961, 0.7882],
          [0.6588, 0.6824, 0.7647,  ..., 

KeyboardInterrupt: 

## Encoder

### ChAda

In [5]:
# Params
PATCH_SIZE = 16
EMBED_DIM = 192
RETURN_ALL_TOKENS = False
MAX_NUMBER_CHANNELS = 10

In [6]:
model = ChAdaViT(
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    return_all_tokens=RETURN_ALL_TOKENS,
    max_number_channels=MAX_NUMBER_CHANNELS,
)

In [8]:
state = torch.load(ChAda.Path.WEIGHTS, map_location="cpu", weights_only=False)["state_dict"]
for k in list(state.keys()):
    if "encoder" in k:
        state[k.replace("encoder", "backbone")] = state[k]
    if "backbone" in k:
        state[k.replace("backbone.", "")] = state[k]
    del state[k]
model.load_state_dict(state, strict=False)
model.to("cpu")
model.eval()
model.mixed_channels = True

In [19]:
print(images.shape, labels.shape)

(480000, 64, 64, 3) (480000, 6)


In [None]:
data = []
for batch in train_loader :
    images, labels = batch
    data_batch = []
    for image, label in zip(images, labels) :
        data_batch.append(images[i],)

    break

<class 'torch.Tensor'>


In [None]:
for batch in train_loader : 
    #unpack batch
    images, labels = batch

    #Format batch in order to collate
    data = []
    for image, label in zip(images, labels) :
        data.append((image, label))
    #Collate images
    channels_list, labels_list, num_channels_list = collate_images(data)

    #Forward
    feats = model(channels_list, index=0, list_num_channel=num_channels_list)
    break

In [22]:
output

tensor([[ 1.3813, -0.8161,  1.0642, -1.8680,  1.3986,  1.1567,  0.6350, -0.0222,
          1.5576,  1.9850,  1.5999,  0.4893,  0.5862,  2.1450, -0.9456, -0.6991,
         -1.2963, -1.2560, -2.1907, -1.9560,  0.1671,  1.2622, -1.2758, -0.8391,
         -0.1157, -1.4872,  1.7005,  0.4986,  0.8209, -1.6139, -1.3855,  0.4419,
         -0.5098, -0.3220,  1.1783,  1.1320, -0.4826, -0.9351, -0.7490, -0.3625,
          0.9052, -0.0501,  0.2452, -0.2175,  0.1449,  0.4216,  0.0167,  0.5555,
          1.0213,  0.9465,  0.8716,  2.6457, -0.9982, -0.3541,  0.2795, -0.2061,
          0.2320, -0.7463, -0.5710, -0.3109,  1.9990,  2.0090,  0.1810,  1.7208,
         -0.3301, -1.1345,  0.5531, -0.2745, -0.2627, -1.6694, -1.1357, -0.2337,
          0.1739,  1.8382,  0.1912, -1.3422,  1.4479, -0.5516, -0.2163, -1.6908,
         -0.4320,  0.0068, -0.3481, -1.8300, -2.1380,  0.4883,  0.9845,  0.4305,
          1.6311, -0.5753,  1.4146, -0.2970,  0.0283,  0.8475, -0.3458,  0.7038,
          0.7426, -1.4992,  

### Variational Encoder