In [None]:
import os
from random import shuffle
import PIL.Image
import torch
from torch.utils.data import Dataset

class TensorFolder(Dataset):
    """Dataset of pre-computed ``.pt`` files arranged as one subfolder per class.

    Each ``.pt`` file must hold a dict with at least the keys ``'vae_latents'``
    and ``'clip_emb'``. Items are returned as ``(latent, clip_emb, label)``.
    """

    def __init__(self, root, target_transform=None, latents_transform=None):
        """
        Args:
            root (str): Root directory with one subfolder per class; only files
                ending in ``.pt`` inside those subfolders are collected.
            target_transform (callable, optional): Transform applied to the
                integer class label.
            latents_transform (callable, optional): Transform applied to the
                latent tensor (after its leading index has been stripped).
        """
        self.root = root
        self.target_transform = target_transform
        self.latents_transform = latents_transform

        # Discover classes; sorting makes class indices independent of the
        # filesystem's listing order.
        with os.scandir(root) as entries:
            self.classes = sorted(entry.name for entry in entries if entry.is_dir())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # Collect all tensor file paths with labels. File names are sorted as
        # well: os.listdir order is arbitrary, and a stable sample order is what
        # makes a seeded train/test split reproducible across machines.
        self.samples = []
        for cls in self.classes:
            cls_dir = os.path.join(root, cls)
            label = self.class_to_idx[cls]
            for fname in sorted(os.listdir(cls_dir)):
                if fname.endswith(".pt"):
                    self.samples.append((os.path.join(cls_dir, fname), label))

    def __len__(self):
        """Number of ``.pt`` files found under ``root``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(latent, clip_emb, label)``.

        ``latent`` is ``sample['vae_latents'][0]``: the leading dimension is
        dropped (presumably a saved batch dimension of size 1 — TODO confirm).
        ``clip_emb`` is returned exactly as stored.
        """
        path, label = self.samples[idx]
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered .pt file cannot execute arbitrary code when loaded.
        sample = torch.load(path, map_location='cpu', weights_only=True)
        latent = sample['vae_latents'].detach()[0]
        clip_emb = sample['clip_emb'].detach()
        if self.latents_transform is not None:
            latent = self.latents_transform(latent)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return latent, clip_emb, label

In [None]:
from random import shuffle
import random
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from kemsekov_torch.train import split_dataset

# Seed both Python's and torch's RNGs up front so the train/test partition
# produced below is reproducible (building the dataset itself draws no randoms).
random_state = 123
random.seed(random_state)
torch.random.manual_seed(random_state)

dataset = TensorFolder('./latents_tree/')

# Hold out 5% of the samples for evaluation; loaders use 8 workers, batch 16.
train_dataset, test_dataset, train_loader, test_loader = split_dataset(
    dataset,
    random_state=random_state,
    test_size=0.05,
    batch_size=16,
    num_workers=8,
    bin_by_size=True,
)

dataset.classes

In [None]:
# import matplotlib.pyplot as plt
# from random import randint
# from vae import decode

# # Set up a 4x4 grid for displaying images
# plt.figure(figsize=(10,10))

# for i in range(4):
#     for j in range(4):
#         index = randint(0, len(dataset) - 1)       # Random index from dataset
#         sample = dataset[index]                    # Select a random sample
#         latent, emb_mu,label = sample
#         # decode latents
#         image_dec = decode(latent[None,:])[0]
#         plt.subplot(4,4,i*4+j+1)
#         plt.title(dataset.classes[label])
#         # Display image on the selected subplot
#         plt.imshow(T.ToPILImage()(image_dec))
#         plt.axis("off")                             # Hide axes for clean view
# print("Latent size",latent.shape)
# plt.tight_layout()
# plt.show()

In [None]:
from kemsekov_torch.common_modules import Residual
from kemsekov_torch.residual import ResidualBlock
from kemsekov_torch.attention import LinearSelfAttentionBlock, EfficientSpatialChannelAttention
import torch.nn as nn
import torch

class TimeContextEmbedding(nn.Module):
    """Add a global (context + time) conditioning vector to a feature map.

    The feature map is projected to ``internal_dim`` channels with a 1x1 conv.
    The context and the time are each embedded to ``internal_dim``, summed,
    RMS-normalised and added at every spatial position. The result is then
    projected back to ``input_dim`` channels with a 1x1 conv.
    """

    def __init__(self, input_dim, context_dim, internal_dim) -> None:
        """
        Args:
            input_dim: Channel count of the feature map ``x``.
            context_dim: Size of the last dimension of ``context``.
            internal_dim: Width of the shared conditioning space.
        """
        super().__init__()

        self.input_2_internal = Residual([
            nn.Conv2d(input_dim, internal_dim, 1)
        ])
        self.context_2_internal = Residual([
            nn.Linear(context_dim, internal_dim)
        ])
        # Small MLP turning a scalar time per sample into an internal_dim vector.
        self.time = nn.Sequential(
            nn.Linear(1, internal_dim),
            nn.ReLU(),
            nn.Linear(internal_dim, internal_dim),
        )
        self.context_norm = nn.RMSNorm(internal_dim)

        self.output = Residual([
            nn.Conv2d(internal_dim, input_dim, 1)
        ])

    def forward(self, x, context, time):
        """
        Args:
            x: Feature map, channels on dim 1 (e.g. (B, input_dim, H, W)).
            context: Conditioning of shape (B, context_dim). Extra dims between
                batch and feature (e.g. (B, 1, context_dim)) are averaged away.
            time: Tensor accepted by ``nn.Linear(1, internal_dim)``, i.e.
                (B, 1) or (1, 1).

        Returns:
            Tensor with ``input_dim`` channels and the spatial layout of ``x``.
        """
        x = self.input_2_internal(x)
        if context.ndim > 2:
            # Reduce to (B, context_dim) *before* adding the time embedding:
            # (B, 1, I) + (B, I) would broadcast to (B, B, I) and mix contexts
            # of different samples across the batch.
            context = context.flatten(1, -2).mean(1)
        context = self.context_2_internal(context) + self.time(time)
        context = self.context_norm(context)
        # Append singleton spatial dims so the vector broadcasts over H, W.
        while context.ndim < x.ndim:
            context = context.unsqueeze(-1)
        # Out-of-place add: never mutate a tensor autograd may have saved.
        return self.output(x + context)

class FlowMatchingModel(nn.Module):
    """Context- and time-conditioned U-Net predicting a flow-matching velocity.

    Layout:
    - a 1x1 expand conv;
    - three stride-2 down blocks, each followed by a TimeContextEmbedding
      (the deepest also by linear self-attention);
    - three transposed up blocks, each fusing its skip connection with a 1x1 conv;
    - a final ResidualBlock back to ``in_channels``.

    NOTE(review): the skip concatenations presumably need H and W divisible
    by 8 (three stride-2 stages) — confirm against ResidualBlock padding.
    """

    def __init__(
        self,
        in_channels,
        context_dim,
        expand_dim=128,
        residual_block_repeats=1,
        time_scale=5.0,
        time_shift=2.5,
        ):
        """
        Args:
            in_channels: Channels of the input latent (and of the output).
            context_dim: Last-dimension size of the conditioning tensor.
            expand_dim: Base width; stages use 1x, 2x, 4x and 8x this value.
            residual_block_repeats: Length of the channel list handed to each
                down/up ResidualBlock.
            time_scale, time_shift: Time is mapped to
                ``time * time_scale - time_shift`` before being embedded; the
                defaults stretch [0, 1] to [-2.5, 2.5].
        """
        super().__init__()
        self.context_dim = context_dim
        self.time_scale = time_scale
        self.time_shift = time_shift
        self.expand = nn.Conv2d(in_channels, expand_dim, 1)
        norm = 'group'

        def down_block(in_ch, out_ch):
            # Stride-2 residual downsampling followed by spatial/channel attention.
            return nn.Sequential(
                ResidualBlock(
                    in_ch,
                    residual_block_repeats * [out_ch],
                    kernel_size=4,
                    stride=2,
                    normalization=norm
                ),
                EfficientSpatialChannelAttention(out_ch)
            )

        def up_block(in_ch, out_ch):
            # Transposed counterpart of down_block.
            return nn.Sequential(
                ResidualBlock(
                    in_ch,
                    residual_block_repeats * [out_ch],
                    kernel_size=4,
                    stride=2,
                    normalization=norm
                ).transpose(),
                EfficientSpatialChannelAttention(out_ch)
            )

        self.down1 = down_block(expand_dim, expand_dim * 2)
        self.attn1 = TimeContextEmbedding(expand_dim * 2, context_dim, expand_dim * 2)

        self.down2 = down_block(expand_dim * 2, expand_dim * 4)
        self.attn2 = TimeContextEmbedding(expand_dim * 4, context_dim, expand_dim * 4)

        self.down3 = down_block(expand_dim * 4, expand_dim * 8)

        self.attn3_1 = TimeContextEmbedding(expand_dim * 8, context_dim, expand_dim * 8)
        self.attn3_2 = LinearSelfAttentionBlock(expand_dim * 8, mlp_dim=expand_dim * 8, heads=16, add_gating=True)

        # Each *_combine fuses [upsampled, skip] (2x channels) back to 1x.
        self.up1 = up_block(expand_dim * 8, expand_dim * 4)
        self.up1_combine = nn.Sequential(
            nn.Conv2d(8 * expand_dim, 4 * expand_dim, 1)
        )

        self.up2 = up_block(expand_dim * 4, expand_dim * 2)
        self.up2_combine = nn.Sequential(
            nn.Conv2d(4 * expand_dim, 2 * expand_dim, 1)
        )

        self.up3 = up_block(expand_dim * 2, expand_dim)
        self.up3_combine = nn.Sequential(
            nn.Conv2d(2 * expand_dim, expand_dim, 1)
        )

        self.final = ResidualBlock(
            expand_dim,
            [expand_dim, in_channels],
            3,
            normalization=norm
        )

    def forward(self, x, context: torch.Tensor, time):
        """Predict the velocity for latent ``x`` at ``time`` given ``context``.

        Args:
            x: Input latent batch with ``in_channels`` on dim 1.
            context: Conditioning tensor whose last dim is ``context_dim``.
            time: Scalar, (B,) or (B, 1) time tensor.

        Returns:
            Tensor with ``in_channels`` channels.
        """
        if time.dim() < 2:
            # Scalar -> (1, 1), (B,) -> (B, 1) for the Linear(1, ...) time MLP.
            time = time.reshape(-1, 1)
        # Widen the time range before embedding (see time_scale / time_shift).
        time = time * self.time_scale - self.time_shift

        x = self.expand(x)

        d1 = self.attn1(self.down1(x), context, time)
        d2 = self.attn2(self.down2(d1), context, time)
        d3 = self.attn3_1(self.down3(d2), context, time)
        # Self-attention is applied with channels moved to the last axis.
        d3 = self.attn3_2(d3.transpose(1, -1)).transpose(1, -1)

        u1 = self.up1_combine(torch.cat([self.up1(d3), d2], 1))
        u2 = self.up2_combine(torch.cat([self.up2(u1), d1], 1))
        u3 = self.up3_combine(torch.cat([self.up3(u2), x], 1))

        return self.final(u3)

In [None]:
# Smoke test: a single forward pass on one sample, checking that the model
# returns a velocity with exactly the shape of its input latent.
latent, emb, label = dataset[0]
model = FlowMatchingModel(
    latent.shape[0],
    emb.shape[-1],
    expand_dim=128,
    residual_block_repeats=1,
)
time = torch.tensor([0.1])

with torch.no_grad():  # no autograd graph needed for a shape check
    smoke_out = model(latent[None, :].float(), emb[None, :].float(), time)
assert smoke_out.shape == latent[None, :].shape, (tuple(smoke_out.shape), tuple(latent.shape))
smoke_out.shape

In [None]:
from kemsekov_torch.train import train
from kemsekov_torch.metrics import r2_score
from kemsekov_torch.flow_matching import FlowMatching

fm = FlowMatching()
loss = nn.MSELoss()

# Weight of the contrastive term that pushes the prediction away from the
# flow direction of a mismatched pair (returned by fm.contrastive_flow_matching_pair).
contrastive_lambda = 0.1
# Fraction of each batch whose conditioning is zeroed out (presumably
# unconditional training for classifier-free guidance — confirm).
context_drop_fraction = 0.5

def compute_loss_and_metric(model, batch):
    """Contrastive flow-matching loss and R^2 metric for one batch.

    Args:
        model: Velocity model called as ``model(x, context, t)``.
        batch: ``(latent, emb_mu, label)`` as yielded by the loaders;
            ``label`` is not used.

    Returns:
        ``(loss, {'r2': r2})`` with
        ``loss = MSE(pred, target) - contrastive_lambda * MSE(pred, contrastive_dir)``.
    """
    latent, emb_mu, label = batch

    # Zero the context of a random subset of samples. Work on a copy so the
    # caller's batch tensor is not mutated as a side effect.
    # NOTE(review): this also runs for the test loader if `train` evaluates
    # with this function, making the test metric stochastic — confirm intended.
    n_drop = int(len(emb_mu) * context_drop_fraction)
    drop_idx = torch.randperm(len(emb_mu))[:n_drop]
    emb_mu = emb_mu.clone()
    emb_mu[drop_idx] = 0

    def run_model(x, t):
        return model(x, emb_mu, t)

    x0 = torch.randn_like(latent)
    pred, target, contrastive_dir, t = fm.contrastive_flow_matching_pair(run_model, x0, latent)

    loss_ = loss(pred, target) - contrastive_lambda * loss(pred, contrastive_dir)
    return loss_, {
        'r2': r2_score(pred, target)
    }

# Optimisation setup. The cosine schedule spans every optimizer step of the
# run, but it is currently NOT handed to `train` (see `scheduler=` below).
epochs = 200
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs * len(train_loader))

# Mixed-precision settings forwarded to accelerate.
# Optional extra: 'dynamo_backend': 'inductor'
accelerate_config = {'mixed_precision': 'bf16'}

train(
    model,
    train_loader,
    test_loader,
    compute_loss_and_metric,
    "runs/vae-tree",  # alternative resume path: "runs/vae-tree/last"
    num_epochs=epochs,
    optimizer=optim,
    # scheduler=sch,
    gradient_clipping_max_norm=1,
    accelerate_args=accelerate_config,
    # ema_args={'beta': 0.995, 'power': 1, 'use_foreach': True},
    checkpoints_count=1,
    skip_n_epochs_before_checkpoint=10,
    save_on_metric_improve=['r2'],
)