In [None]:
import os
from random import shuffle
import PIL.Image as Image
import torch
from torch.utils.data import Dataset

class TensorFolder(Dataset):
    """Dataset over a class-per-subfolder tree of ``.pt`` sample files.

    Every ``.pt`` file is read with ``torch.load`` and is expected to hold a
    mapping with at least these keys:

    * ``'im_path'``  - path of the image file that belongs to the sample,
    * ``'clip_emb'`` - a tensor; it is returned detached from autograd.

    Each item is the tuple ``(image, clip_emb, class_index)``.
    """

    def __init__(self, root, target_transform=None):
        """
        Args:
            root (str): Root directory with one subfolder per class.
            target_transform (callable, optional): Despite its name, this
                callable is applied to the loaded *image*, not to the label.
                The parameter name is kept for backward compatibility.
        """
        self.root = root
        self.target_transform = target_transform

        # Discover classes; sorting keeps the class -> index mapping stable.
        self.classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # Collect all tensor file paths with labels.
        # os.listdir() returns names in arbitrary order, so the names are
        # sorted: a given index then always refers to the same file, which a
        # seeded train/test split relies on.
        self.samples = []
        for cls in self.classes:
            cls_dir = os.path.join(root, cls)
            label = self.class_to_idx[cls]
            for fname in sorted(os.listdir(cls_dir)):
                if fname.endswith(".pt"):
                    self.samples.append((os.path.join(cls_dir, fname), label))

    def __len__(self):
        """Return the number of discovered ``.pt`` samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, clip_emb, class_index)``.

        The image is a ``PIL.Image`` unless ``target_transform`` converts it.
        """
        path, label = self.samples[idx]
        # NOTE: torch.load may unpickle arbitrary objects - only point this
        # dataset at files you trust.
        sample = torch.load(path, map_location='cpu')   # Load pre-saved tensor
        # Image.open is lazy; read the pixel data inside the `with` block so
        # the underlying file handle is released deterministically.
        with Image.open(sample['im_path']) as im:
            im.load()
        clip_emb = sample['clip_emb'].detach()
        if self.target_transform is not None:
            im = self.target_transform(im)

        return im, clip_emb, label

In [None]:
from random import shuffle
import random
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from kemsekov_torch.train import split_dataset

def fix_channels(im):
    """Coerce a channel-first image tensor to three channels.

    Args:
        im: Tensor whose first dimension is the channel axis.

    Returns:
        A tensor with the same trailing dimensions and three channels:
        * more than 3 channels (e.g. RGBA): only the first three are kept;
        * 1 or 2 channels (grayscale, optionally with a second channel such
          as alpha): channel 0 is replicated three times;
        * exactly 3 channels: returned unchanged.
    """
    im = im[:3]
    # Anything with fewer than three (but at least one) channel is treated as
    # grayscale: replicate the first channel so a 3-channel Normalize works.
    if im.shape[0] in (1, 2):
        im = im[[0, 0, 0]]
    return im

# Image preprocessing pipeline. It is handed to TensorFolder below, which
# applies it to every image it loads.
# Order: PIL image -> tensor, resize with size 256, random 256x256 crop,
# channel fix-up via fix_channels, then per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics - confirm.
tr = T.Compose([
    T.ToTensor(),
    T.Resize(256),
    T.RandomCrop((256,256)),
    T.Lambda(fix_channels),
    T.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])

# Samples are read from the class-per-subfolder tree under ./latents/.
dataset = TensorFolder('./latents/',tr)
# Seed torch and Python's `random` before splitting; the same seed is also
# passed to split_dataset below.
random_state = 123
torch.random.manual_seed(random_state)
random.seed(random_state)

# Split the dataset into train and test parts and build their loaders.
# The call returns (train_dataset, test_dataset, train_loader, test_loader).
# NOTE(review): test_size=0.05 presumably means a 5% hold-out fraction - confirm against split_dataset.
train_dataset,test_dataset,train_loader, test_loader = split_dataset(
    dataset,
    test_size=0.05,
    num_workers=16,
    batch_size=16,
    random_state=random_state,
    # bin_by_size=True
)
# Last expression of the cell: display the discovered class names.
dataset.classes

In [None]:
import matplotlib.pyplot as plt
from random import randint
from vae import decode

# Visual sanity check: show a 4x4 grid of random dataset samples with their class names.
#
# The dataset transform `tr` ends with T.Normalize(mean, std), so the images
# are not in [0, 1] any more. Invert that normalisation (x * std + mean) and
# clamp before display; these values must match the ones used in `tr`.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

plt.figure(figsize=(10,10))

for i in range(4):
    for j in range(4):
        index = randint(0, len(dataset) - 1)       # Random index from dataset
        im, _clip_emb, label = dataset[index]      # The embedding is not needed for the preview
        plt.subplot(4,4,i*4+j+1)
        plt.title(dataset.classes[label])
        # Undo the normalisation so colours and contrast are shown faithfully
        preview = (im * norm_std + norm_mean).clamp(0, 1)
        plt.imshow(T.ToPILImage()(preview))
        plt.axis("off")                             # Hide axes for clean view
plt.tight_layout()
plt.show()

In [None]:
from kemsekov_torch.common_modules import Residual
from kemsekov_torch.residual import ResidualBlock
from kemsekov_torch.attention import LinearSelfAttentionBlock,LinearCrossAttentionBlock, EfficientSpatialChannelAttention
from kemsekov_torch.attention import MultiHeadLinearAttention
import torch.nn as nn
import torch

class CrossAttention(nn.Module):
    """Attention block that conditions a feature tensor on a context and a time value.

    The forward pass runs, in order: a projection of the input into
    ``internal_dim``, self-attention over the input, cross-attention from the
    input (queries) to the time-shifted context (keys/values), and an MLP that
    maps back to ``in_channels``. The original input is added to the result.

    NOTE(review): ``Residual`` and ``MultiHeadLinearAttention`` come from
    ``kemsekov_torch``; their exact semantics are not visible here.
    """
    def __init__(self,in_channels,context_channels,internal_dim=128):
        """
        Args:
            in_channels: Size of the input's dim 1 (the axis the linear layers act on
                after the transpose in ``forward``).
            context_channels: Input feature size of the context projection.
            internal_dim: Width used inside the block for attention and normalisation.
        """
        super().__init__()
        def norm(ch):
            # Single switch point for the normalisation used by every sub-layer.
            # return nn.Identity()
            return nn.RMSNorm(ch)
            # return nn.LayerNorm(ch)
        
        # Projects the input features from in_channels to internal_dim.
        # NOTE(review): wrapped in Residual - how it behaves when in/out sizes differ is library-defined.
        self.input_2_internal = Residual([
            nn.Linear(in_channels,internal_dim)
            # norm(internal_dim)
        ])
        
        # Projects the context features to internal_dim.
        self.context_2_internal = nn.Linear(context_channels,internal_dim)
        # Embeds a scalar time value (last dim of size 1) into internal_dim.
        self.time = nn.Sequential(
            nn.Linear(1,internal_dim),
            nn.ReLU(),
            nn.Linear(internal_dim,internal_dim),
        )
        self.context_norm = norm(internal_dim)

        # Self-attention: one linear layer produces Q, K and V concatenated on the last dim.
        self.sa_QKV =nn.Sequential(
            nn.Linear(
                internal_dim,
                internal_dim*3,
            )
        )
        self.sa_norm = norm(internal_dim)
        # At least 4 heads, otherwise one head per 16 channels.
        self.lsa = MultiHeadLinearAttention(
            internal_dim,
            n_heads=max(4,internal_dim//16),
            dropout=0,
            use_classic_attention=True,
            add_rotary_emb=True
        )
        
        self.cross_norm = norm(internal_dim)
        # Cross-attention module; unlike self.lsa it is built without add_rotary_emb.
        self.lca = MultiHeadLinearAttention(
            internal_dim,
            n_heads=max(4,internal_dim//16),
            dropout=0,
            use_classic_attention=True
        )
        
        # Queries for cross-attention come from the input features ...
        self.cross_Q = nn.Sequential(
            nn.Linear(
                internal_dim,
                internal_dim,
            )
        )
        
        # ... keys and values come from the context (concatenated on the last dim).
        self.cross_KV = nn.Sequential(
            nn.Linear(
                internal_dim,
                internal_dim*2,
            )
        )
        self.mlp_norm = norm(internal_dim)
        # Feed-forward part: internal_dim -> 4*internal_dim -> in_channels.
        # NOTE(review): init_at_zero=True presumably makes the block start as an identity mapping - confirm in Residual.
        self.mlp = Residual([
            nn.Linear(internal_dim,4*internal_dim),
            nn.GELU(),
            nn.Linear(4*internal_dim,in_channels),
        ],init_at_zero=True)
        
    def forward(self,x,context,time):
        """Apply self-attention, cross-attention and the MLP, then add the input back.

        Args:
            x: Input features; dim 1 must have size ``in_channels``.
                # assumes (batch, channels, H, W) as produced by the conv stages - TODO confirm
            context: Conditioning tensor; after swapping dim 1 and the last dim,
                its last dim must have size ``context_channels``.
            time: Tensor whose last dim has size 1; its embedding must broadcast
                against the projected context.

        Returns:
            Tensor with the same layout as ``x`` (the result is transposed back
            and the untouched input is added).
        """
        x_input = x
        # Swap dim 1 with the last dim so the linear layers act on the former dim 1.
        x,context = x.transpose(1,-1),context.transpose(1,-1)
        x = self.input_2_internal(x)
        context = self.context_2_internal(context)
        # Inject the time embedding into the context.
        context=context+self.time(time)
        
        # Pre-norm self-attention with a residual connection.
        # Only the first element of the attention module's return value is used.
        q,k,v = self.sa_QKV(self.sa_norm(x)).chunk(3,-1)
        x = self.lsa(q,k,v)[0]+x
         
        # Pre-norm cross-attention: queries from x, keys/values from the context.
        q = self.cross_Q(self.cross_norm(x))
        k,v = self.cross_KV(self.context_norm(context)).chunk(2,-1)
        x = self.lca(q,k,v)[0]+x
        
        # MLP back to in_channels, restore the original layout, add the input.
        return self.mlp(self.mlp_norm(x)).transpose(1,-1)+x_input

class FlowMatchingModel(nn.Module):
    """U-Net style network conditioned on a context tensor and a time value.

    Structure, as built in ``__init__``:

    * a 1x1 convolution widens the input to ``expand_dim`` channels;
    * five stride-2 ``ResidualBlock`` stages double the channel count each time
      (``expand_dim*2`` ... ``expand_dim*32``);
    * a ``CrossAttention`` block follows each of the two deepest stages;
    * five transposed stride-2 ``ResidualBlock`` stages halve the channel count
      again, each followed by an additive skip connection from the encoder;
    * a final ``ResidualBlock`` maps back to ``in_channels``.

    NOTE(review): the skip additions require every up stage to reproduce the
    spatial size of its encoder counterpart; assuming each stride-2 stage
    halves the resolution, the input height/width should be divisible by 32.
    """

    def __init__(
        self, 
        in_channels, 
        context_dim,
        expand_dim = 128,
        residual_block_repeats = 1,
        ):
        """
        Args:
            in_channels: Number of channels of the input (and of the output).
            context_dim: Feature size of the context passed to the attention blocks.
            expand_dim: Base channel width; deeper stages use multiples of it.
            residual_block_repeats: Length of the per-stage channel list given
                to each ``ResidualBlock``.
        """
        super().__init__()
        # Submodule names and construction order are kept stable on purpose:
        # they define the state_dict keys of saved checkpoints.
        norm = 'batch'
        self.context_dim=context_dim
        self.expand = nn.Conv2d(in_channels,expand_dim,1)

        # ---- encoder ----
        self.down1 = nn.Sequential(
            ResidualBlock(expand_dim,residual_block_repeats*[expand_dim*2],4,stride=2,normalization=norm),
        )

        self.down2 = nn.Sequential(
            ResidualBlock(expand_dim*2,residual_block_repeats*[expand_dim*4],4,stride=2,normalization=norm),
        )

        self.down3 = nn.Sequential(
            ResidualBlock(expand_dim*4,residual_block_repeats*[expand_dim*8],4,stride=2,normalization=norm),
        )

        self.down4 = nn.Sequential(
            ResidualBlock(expand_dim*8,residual_block_repeats*[expand_dim*16],4,stride=2,normalization=norm),
        )
        self.attn4 = CrossAttention(expand_dim*16,context_dim,expand_dim*16)

        self.down5 = nn.Sequential(
            ResidualBlock(expand_dim*16,residual_block_repeats*[expand_dim*32],4,stride=2,normalization=norm),
        )
        self.attn5 = CrossAttention(expand_dim*32,context_dim,expand_dim*32)

        # ---- decoder (``.transpose()`` presumably yields the upsampling variant of the block) ----
        self.up1 = nn.Sequential(
            ResidualBlock(expand_dim*32,residual_block_repeats*[expand_dim*16],4,stride=2,normalization=norm).transpose(),
        )

        self.up2 = nn.Sequential(
            ResidualBlock(expand_dim*16,residual_block_repeats*[expand_dim*8],4,stride=2,normalization=norm).transpose(),
        )

        self.up3 = nn.Sequential(
            ResidualBlock(expand_dim*8,residual_block_repeats*[expand_dim*4],4,stride=2,normalization=norm).transpose(),
        )

        self.up4 = nn.Sequential(
            ResidualBlock(expand_dim*4,residual_block_repeats*[expand_dim*2],4,stride=2,normalization=norm).transpose(),
        )

        self.up5 = nn.Sequential(
            ResidualBlock(expand_dim*2,residual_block_repeats*[expand_dim],4,stride=2,normalization=norm).transpose(),
        )

        self.final = ResidualBlock(
            expand_dim,
            [expand_dim,in_channels],
            3,
            normalization=norm
        )

    def forward(self,x, context : torch.Tensor, time):
        """Run the network.

        Args:
            x: Input tensor with ``in_channels`` channels.
            context: Conditioning tensor forwarded to both attention blocks.
            time: Tensor holding one time value per batch element. A 0-dim or
                1-dim tensor is reshaped to a column of shape ``(N, 1)``;
                tensors with two or more dims are used as given.

        Returns:
            The output of the final ``ResidualBlock``, built with ``in_channels`` output channels.
        """
        if time.dim()<2:
            # reshape also covers a 0-dim scalar tensor, which `time[:, None]` would reject
            time = time.reshape(-1,1)
        # Affine rescale of the time value: t -> 5*t - 2.5
        # (maps [0, 1] onto [-2.5, 2.5], assuming time is given in [0, 1]).
        time=time*5-2.5

        x=self.expand(x)

        # Encoder; attention is only applied at the two deepest stages.
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)

        d4 = self.down4(d3)
        d4 = self.attn4(d4,context,time)

        d5 = self.down5(d4)
        d5 = self.attn5(d5,context,time)

        # Decoder with additive skip connections from the matching encoder stage.
        u1 = self.up1(d5)+d4
        u2 = self.up2(u1)+d3
        u3 = self.up3(u2)+d2
        u4 = self.up4(u3)+d1
        u5 = self.up5(u4)+x

        return self.final(u5)

In [None]:
# Build the model from the shapes of one real sample and run a single
# forward pass on a batch of one to check that the output shape is sane.
im, emb, label = dataset[0]
model = FlowMatchingModel(
    in_channels=im.shape[0],
    context_dim=emb.shape[-1],
    expand_dim=16,
    residual_block_repeats=1,
)
time = torch.Tensor([0.1])

# unsqueeze(0) adds the leading batch dimension
model(im.unsqueeze(0).float(), emb.unsqueeze(0).float(), time).shape

In [None]:
from kemsekov_torch.train import train
from kemsekov_torch.metrics import r2_score
from kemsekov_torch.flow_matching import FlowMatching
from kemsekov_torch.muon import muon_optimizer

fm = FlowMatching()
loss = nn.MSELoss()

# Weight of the term that is subtracted from the loss (see compute_loss_and_metric).
contrastive_lambda = 0.05
def compute_loss_and_metric(model,batch):
    """Compute the training loss and the R2 metric for one batch.

    Args:
        model: Callable invoked as ``model(x, context, t)``.
        batch: Tuple ``(latent, emb_mu, label)`` as yielded by the data loaders;
            ``label`` is not used.

    Returns:
        ``(loss, metrics)`` where ``loss`` is
        ``MSE(pred, target) - contrastive_lambda * MSE(pred, contrastive_dir)``
        and ``metrics`` is ``{'r2': r2_score(pred, target)}``.
    """
    latent, emb_mu, _label = batch

    # Zero the conditioning of a random half of the batch.
    # This is done on a copy so the caller's batch tensor is not modified in place.
    emb_mu = emb_mu.clone()
    dropped = torch.randperm(len(emb_mu))[:len(emb_mu)//2]
    emb_mu[dropped] = 0

    def run_model(x,t):
        # Adapter: the flow-matching helper calls the model with (x, t) only.
        return model(x,emb_mu,t)

    # Noise sample with the same shape as the data.
    x0 = torch.randn_like(latent)
    # NOTE(review): the helper's exact semantics live in kemsekov_torch; it
    # returns the prediction, its regression target, a contrastive direction and t.
    pred,target,contrastive_dir,t = fm.contrastive_flow_matching_pair(run_model,x0,latent)

    # Pull the prediction towards the target and push it away from the contrastive direction.
    loss_ = loss(pred,target) - contrastive_lambda*loss(pred,contrastive_dir)
    return loss_,{
        'r2':r2_score(pred,target)
    }

# Number of epochs; also used below to size the LR schedule.
epochs = 1000

# AdamW with learning rate 1e-3 and a cosine schedule whose horizon (T_max)
# is the total number of batches over the whole run.
# NOTE(review): gradient_accumulation_steps=4 is configured below. If `train`
# steps the scheduler once per optimizer step rather than once per batch, the
# schedule would only cover a quarter of its horizon - confirm against `train`.
optim = torch.optim.AdamW(model.parameters(),1e-3)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim,len(train_loader)*epochs)
# Launch training. The string argument is presumably the run/output directory
# (verify against kemsekov_torch.train.train); the commented alternative below
# it looks like a path to resume from.
train(
    model,
    train_loader,
    test_loader,
    compute_loss_and_metric,
    "runs/ims-natural",
    # "runs/ims-natural/last",
    num_epochs=epochs,
    checkpoints_count=1,
    # 'r2' is the metric key returned by compute_loss_and_metric
    save_on_metric_improve=['r2'],
    gradient_clipping_max_norm=1,
    accelerate_args={
        'mixed_precision':'bf16',
        'gradient_accumulation_steps':4,
        # 'dynamo_backend':'inductor'
    },
    # Settings for the exponential moving average of the weights (passed through to `train`).
    ema_args={
        'beta':0.995,
        'power':1,
        'use_foreach':True
    },
    optimizer=optim,
    scheduler=sch
)