In [None]:
import os
import torch
from torch.utils.data import Dataset

class TensorFolder(Dataset):
    """Dataset of pre-saved tensors laid out as ``root/<class_name>/<file>.pt``.

    Follows the folder convention of ``torchvision.datasets.ImageFolder``:
    every immediate subdirectory of ``root`` is a class, and a class's index is
    the position of its name in sorted order. Only files ending in ``.pt`` are
    collected; anything else in a class folder is ignored.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """
        Args:
            root (str): Root directory with one subfolder per class.
            transform (callable, optional): Applied to each loaded tensor.
            target_transform (callable, optional): Applied to each integer label.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        # Every immediate subdirectory is a class; sorting keeps indices stable.
        self.classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # (path, label) pairs. File names are sorted because os.listdir order is
        # filesystem-dependent; a stable order keeps index-based (seeded)
        # train/test splits reproducible across machines.
        self.samples = []
        for cls_name in self.classes:
            cls_dir = os.path.join(root, cls_name)
            label = self.class_to_idx[cls_name]
            for fname in sorted(os.listdir(cls_dir)):
                if fname.endswith(".pt"):
                    self.samples.append((os.path.join(cls_dir, fname), label))

    def __len__(self):
        """Return the number of ``.pt`` files found under ``root``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(tensor, label)`` after the transforms."""
        path, label = self.samples[idx]
        # map_location="cpu" keeps loading on the CPU, so DataLoader worker
        # processes never need CUDA even if a file was saved from a GPU tensor.
        # NOTE(review): torch.load unpickles the file - only load trusted data.
        tensor = torch.load(path, map_location="cpu")

        # Compare with None rather than truthiness: a callable that defines
        # __len__/__bool__ could be falsy and would otherwise be skipped silently.
        if self.transform is not None:
            tensor = self.transform(tensor)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return tensor, label


In [None]:
from random import shuffle
import random
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from kemsekov_torch.train import split_dataset


random_state = 123
# Seed torch and Python's `random` before any stochastic work in the notebook.
torch.random.manual_seed(random_state)
random.seed(random_state)

dataset = TensorFolder('./latents_tree')

# Hold out 5% of the samples for evaluation; the same seed is handed to the
# splitter so the partition is repeatable.
(
    train_dataset,
    test_dataset,
    train_loader,
    test_loader,
) = split_dataset(
    dataset,
    test_size=0.05,
    num_workers=16,
    batch_size=16,
    random_state=random_state,
    bin_by_size=True,
)

In [None]:
# import matplotlib.pyplot as plt
# from random import randint
# from vae import decode

# # Set up a 4x4 grid for displaying images
# plt.figure(figsize=(10,10))

# for i in range(4):
#     for j in range(4):
#         index = randint(0, len(dataset) - 1)       # Random index from dataset
#         sample = dataset[index]                    # Select a random sample
#         image, label = sample[0], sample[1]        # Separate image and label
#         # decode latents
#         image_dec = decode(image[None,:])[0]
#         plt.subplot(4,4,i*4+j+1)
#         plt.title(dataset.classes[label])
#         # Display image on the selected subplot
#         plt.imshow(T.ToPILImage()(image_dec))
#         plt.axis("off")                             # Hide axes for clean view
# print("Latent size",image.shape)
# plt.tight_layout()
# plt.show()

In [None]:
from kemsekov_torch.resunet import ResidualUnet
from kemsekov_torch.attention import LinearSelfAttentionBlock, EfficientSpatialChannelAttention
from kemsekov_torch.common_modules import Transpose, Repeat
from kemsekov_torch.positional_emb import AddPositionalEmbeddingPermute
import torch.nn.functional as F
import torch.nn as nn
class FlowMatchingModel(nn.Module):
    """Class- and time-conditioned ResidualUnet used as the flow-matching network.

    For every UNet level, a class embedding and a time embedding are
    concatenated, projected back to that level's width, layer-normalized and
    handed to ``ResidualUnet.encode_with_context`` as that level's context.
    """

    def __init__(self, in_channels, num_classes):
        """
        Args:
            in_channels (int): Passed as both channel arguments of ResidualUnet
                (presumably input and output channels of the latent).
            num_classes (int): Number of class labels; width of the one-hot vector.
        """
        super().__init__()
        self.num_classes = num_classes
        channels = [64, 128, 256, 512]

        self.unet = ResidualUnet(
            in_channels,
            in_channels,
            channels=channels,
            repeats=2,
            normalization='group',
            activation=torch.nn.SiLU,
            attention=EfficientSpatialChannelAttention,
            bottom_layer=self.get_lsa(channels, 2),
        )

        # One projection per UNet level: class vector -> that level's width.
        self.class_to_emb = nn.ModuleList(
            [nn.Linear(num_classes, c) for c in channels]
        )
        # Per-level MLP embedding of the (B, 1) time input.
        self.time_emb = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(1, c),
                    nn.GELU(),
                    nn.Linear(c, c),
                )
                for c in channels
            ]
        )
        # Fuses the concatenated [class_emb, time_emb] (2c features) back to c.
        self.combine = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(c * 2, c),
                    nn.LayerNorm(c),
                )
                for c in channels
            ]
        )
        # NOTE(review): not referenced by forward/unet1; presumably kept so the
        # state_dict keys match existing checkpoints - confirm before removing.
        self.add_pos = nn.ModuleList(
            [AddPositionalEmbeddingPermute(c) for c in channels]
        )

    def get_lsa(self, channels, attention_layers):
        """Build the UNet bottleneck: ``attention_layers`` self-attention blocks.

        ``Transpose(1, -1)`` wraps the attention blocks on both sides,
        presumably moving channels to the last axis and back - TODO confirm
        against kemsekov_torch.
        """
        return nn.Sequential(
            Transpose(1, -1),
            Repeat(
                LinearSelfAttentionBlock(
                    channels[-1],
                    channels[-1],
                    16,
                    add_rotary_emb=True,
                    use_classic_attention=True,
                    add_gating=True,
                    add_local_attention=False,
                ),
                attention_layers,
            ),
            Transpose(1, -1),
        )

    def forward(self, x, time, cls):
        """Run the conditioned UNet on ``x``.

        Args:
            x (Tensor): Input batch; assumed (B, C, H, W) - TODO confirm.
            time (Tensor): Per-sample time of shape (B,) or (B, 1). A 0-d tensor
                is broadcast to every sample of the batch.
            cls (Tensor): Either integer class indices of shape (B,) - values
                outside ``[0, num_classes)`` (e.g. ``-1``) mean "no label" and
                yield an all-zero class vector - or a floating-point tensor
                used directly as the class vector (presumably (B, num_classes)).

        The caller's ``cls`` tensor is never modified.
        """
        if not cls.is_floating_point():
            # one_hot only accepts int64 indices, so normalize every integer dtype.
            labels = cls.long()
            no_context = (labels < 0) | (labels >= self.num_classes)
            # masked_fill is out-of-place: invalid labels are remapped to a valid
            # index on a copy, so one_hot cannot fail and the caller's tensor
            # keeps its -1 markers.
            labels = labels.masked_fill(no_context, 0)
            one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float()
            # Unlabeled samples get an all-zero class vector (unconditional).
            cls = one_hot.masked_fill(no_context.unsqueeze(-1), 0.0).to(x.device)

        if time.dim() == 0:
            # A single scalar time is shared by the whole batch.
            time = time.expand(x.shape[0])
        if time.dim() < 2:
            time = time[:, None]

        return self.unet1(x, time, cls)

    def unet1(self, x, time, cls):
        """Encode ``x`` with per-level class/time context and decode it.

        Each context is the fused embedding with two trailing singleton dims
        appended (B, c, 1, 1 for 2-D inputs), so it can broadcast over the
        spatial axes.
        """
        context = []
        for class_emb, time_emb, combine in zip(self.class_to_emb, self.time_emb, self.combine):
            fused = combine(torch.concat([class_emb(cls), time_emb(time)], -1))
            context.append(fused[:, :, None, None])
        # NOTE(review): the deepest level's context is zeroed, presumably to
        # disable conditioning at the attention bottleneck - confirm intent.
        context[-1] *= 0
        skip = self.unet.encode_with_context(x, context)
        return self.unet.decode(skip)

num_classes = len(dataset.classes)
# 4 = channels of the saved latents (must match the .pt files - TODO confirm).
model = FlowMatchingModel(4, num_classes)

# Smoke test: one forward pass on random latents to check the output shape.
# Label -1 exercises the "no label" path. no_grad avoids building an autograd
# graph through the whole UNet for this throwaway check.
im = torch.randn(4, 4, 128, 128)
time = torch.tensor([0.1, 0.9, 0.2, 0.3])
classes = torch.tensor([1, 2, 3, -1], dtype=torch.long)
with torch.no_grad():
    out_shape = model(im, time, classes).shape
out_shape

In [None]:
from kemsekov_torch.train import train
from kemsekov_torch.metrics import r2_score
from kemsekov_torch.flow_matching import FlowMatching

fm = FlowMatching()
loss = nn.MSELoss()

# Weight of the contrastive term. It is subtracted from the loss, so predictions
# that move away from contrastive_dir are rewarded.
contrastive_lambda = 0.05


def compute_loss_and_metric(model, batch, label_drop_fraction=0.25):
    """Contrastive flow-matching loss and R^2 metric for one batch.

    Args:
        model: Called as ``model(x, t, cls)``.
        batch: ``(image, cls)`` pair of latents and integer class labels.
        label_drop_fraction: Fraction of the batch (rounded down) whose label
            is replaced with ``-1`` ("no label"), so the model also learns the
            unconditional flow. The default 0.25 equals the former ``len // 4``.

    Returns:
        ``(loss, {'r2': r2})`` where
        ``loss = MSE(pred, target) - contrastive_lambda * MSE(pred, contrastive_dir)``.
    """
    image, cls = batch

    # Drop labels on a copy so the batch's own label tensor is left untouched.
    cls = cls.clone()
    n_drop = int(len(cls) * label_drop_fraction)
    no_label_samples = torch.randperm(len(cls))[:n_drop]
    cls[no_label_samples] = -1

    def run_model(x, t):
        return model(x, t, cls)

    x0 = torch.randn_like(image)
    pred, target, contrastive_dir, t = fm.contrastive_flow_matching_pair(run_model, x0, image)

    loss_ = loss(pred, target) - contrastive_lambda * loss(pred, contrastive_dir)
    return loss_, {
        'r2': r2_score(pred, target)
    }

epochs = 400
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
# NOTE(review): this scheduler is constructed but not passed to train() below
# (the scheduler=sch argument is commented out).
sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=len(train_loader))

train(
    model,
    train_loader,
    test_loader,
    compute_loss_and_metric,
    "runs/vae-tree",
    "runs/vae-tree/last",
    # NOTE(review): runs 2 * epochs (= 800) epochs - confirm this is intended.
    num_epochs=epochs * 2,
    save_on_metric_improve=['r2'],
    gradient_clipping_max_norm=1,
    accelerate_args=dict(
        mixed_precision='bf16',
        # dynamo_backend='inductor',
    ),
    ema_args=dict(
        update_after_step=100,
        beta=0.999,
        update_every=1,
        power=1,
    ),
    optimizer=optim,
    # scheduler=sch,
)