In [1]:
import os
import torch
from torch.utils.data import Dataset

class TensorFolder(Dataset):
    """Dataset of pre-saved tensors stored as ``root/<class_name>/<file>.pt``.

    Every immediate sub-directory of ``root`` is one class. Class indices are
    assigned in sorted order of the directory names, and every file whose name
    ends with ``.pt`` inside a class directory becomes one sample.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """
        Args:
            root (str): Root directory with subfolders per class.
            transform (callable, optional): Optional transform to apply to tensors.
            target_transform (callable, optional): Optional transform to apply to labels.

        Raises:
            OSError: If ``root`` or one of its class directories cannot be listed.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        # Discover classes: one class per sub-directory, sorted by name so the
        # name -> index mapping does not depend on filesystem enumeration order.
        with os.scandir(root) as entries:
            self.classes = sorted(entry.name for entry in entries if entry.is_dir())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # Collect all tensor file paths with labels as (path, class_index) pairs.
        self.samples = []
        for cls in self.classes:
            cls_dir = os.path.join(root, cls)
            # os.listdir returns names in arbitrary order; sorting keeps sample
            # indices stable, which any index-based seeded split relies on.
            for fname in sorted(os.listdir(cls_dir)):
                if fname.endswith(".pt"):
                    path = os.path.join(cls_dir, fname)
                    self.samples.append((path, self.class_to_idx[cls]))

    def __len__(self):
        """Return the number of ``.pt`` files found under all class directories."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            tuple: ``(tensor, label)`` where ``tensor`` is the object stored in
            the file (after ``transform``, if given) and ``label`` is the class
            index (after ``target_transform``, if given).
        """
        path, label = self.samples[idx]
        # NOTE: torch.load unpickles the file, so only point this dataset at
        # files you trust. map_location="cpu" keeps loading independent of the
        # device the tensor was saved from (DataLoader workers must not need CUDA).
        tensor = torch.load(path, map_location="cpu")   # Load pre-saved tensor

        # Compare against None rather than truthiness: a callable that defines
        # __len__ (e.g. an empty container module) would otherwise be skipped.
        if self.transform is not None:
            tensor = self.transform(tensor)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return tensor, label


In [2]:
from random import shuffle
import random
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from kemsekov_torch.train import split_dataset


# Fix the seed of every RNG used below before anything stochastic runs.
random_state = 123
random.seed(random_state)
torch.random.manual_seed(random_state)

# Pre-computed latents, one sub-folder per class under ./latents/.
dataset = TensorFolder('./latents/')

# split dataset into train and test (5% held out); the helper also returns
# the matching data loaders.
train_dataset, test_dataset, train_loader, test_loader = split_dataset(
    dataset,
    test_size=0.05,
    batch_size=16,
    num_workers=16,
    bin_by_size=True,
    random_state=random_state,
)
dataset.classes

  from .autonotebook import tqdm as notebook_tqdm
Bin by tensor size: 100%|██████████| 10676/10676 [00:03<00:00, 3305.30it/s]


unique 10674
total 10800


Bin by tensor size: 100%|██████████| 562/562 [00:00<00:00, 3025.68it/s]

unique 562
total 640
Train items 10800
Test items 640





['Birds', 'cat', 'dog']

In [3]:
# import matplotlib.pyplot as plt
# from random import randint
# from vae import decode

# # Set up a 4x4 grid for displaying images
# plt.figure(figsize=(10,10))

# for i in range(4):
#     for j in range(4):
#         index = randint(0, len(dataset) - 1)       # Random index from dataset
#         sample = dataset[index]                    # Select a random sample
#         image, label = sample[0], sample[1]        # Separate image and label
#         # decode latents
#         image_dec = decode(image[None,:])[0]
#         plt.subplot(4,4,i*4+j+1)
#         plt.title(dataset.classes[label])
#         # Display image on the selected subplot
#         plt.imshow(T.ToPILImage()(image_dec))
#         plt.axis("off")                             # Hide axes for clean view
# print("Latent size",image.shape)
# plt.tight_layout()
# plt.show()

In [4]:
from typing import List
from kemsekov_torch.resunet import ResidualUnet
from kemsekov_torch.residual import ResidualBlock
from kemsekov_torch.attention import LinearSelfAttentionBlock, EfficientSpatialChannelAttention
from kemsekov_torch.common_modules import Transpose, Repeat
from kemsekov_torch.positional_emb import AddPositionalEmbeddingPermute
import torch.nn.functional as F
import torch.nn as nn
def get_lsa(c,repeats):
    """Build the self-attention stack used as the U-Net bottom layer.

    The returned ``nn.Sequential`` is, in order: ``Transpose(1,-1)``, two
    ``Repeat`` wrappers (each built around its own freshly constructed
    ``LinearSelfAttentionBlock`` and given ``repeats``), and a final
    ``Transpose(1,-1)``.

    Args:
        c: Passed as both the first and second positional argument of every
            ``LinearSelfAttentionBlock``.
        repeats: Second argument of each ``Repeat`` wrapper.

    Returns:
        nn.Sequential: The assembled stack.
    """
    # Third positional argument of LinearSelfAttentionBlock: c//16 clamped to
    # the range [4, 32]. (Presumably a head count/size - confirm in the library.)
    clamped = max(4,min(c//16,32))

    def make_stage():
        # A new attention block per stage, so the two stages do not share an instance.
        attention = LinearSelfAttentionBlock(
            c,
            c,
            clamped,
            add_rotary_emb=True,
            use_classic_attention=True,
            add_gating=True,
            add_local_attention=False
        )
        return Repeat(attention, repeats)

    layers = [Transpose(1,-1)]
    for _ in range(2):
        layers.append(make_stage())
    layers.append(Transpose(1,-1))
    return nn.Sequential(*layers)
class FlowMatchingModel(nn.Module):
    """U-Net conditioned on a class label and a time value.

    For every level in ``[64, 64, 128, 256, 512]`` the model owns:

    * ``class_to_emb``: linear map from a ``num_classes``-wide class vector to
      the level width,
    * ``time_emb``: small MLP from a single time scalar to the level width,
    * ``combine``: linear + LayerNorm fusing the two into one context vector,
    * ``middle``: a residual/attention stack applied to that level's skip tensor.

    The contexts are handed to ``ResidualUnet.encode_with_context`` and are
    also added to the processed skip tensors before ``ResidualUnet.decode``.
    """

    # Integer dtypes treated as "class index" input in ``forward``.
    _INDEX_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)

    def __init__(self, in_channels,classes : List[str]):
        """
        Args:
            in_channels: Passed to ``ResidualUnet`` as both its first and second
                positional argument.
            classes: Class names; only their count determines layer sizes, the
                list itself is stored on ``self.classes``.
        """
        super().__init__()
        self.num_classes=len(classes)
        self.classes=classes
        channels = [64,128,256,512]
        # Width of every conditioned level: the first width appears twice.
        level_channels = [channels[0]]+channels

        common = dict(
            normalization='group',
            activation=torch.nn.SiLU,
        )
        self.unet = ResidualUnet(
            in_channels,
            in_channels,
            channels=channels,
            repeats=2,
            attention = EfficientSpatialChannelAttention,
            bottom_layer=get_lsa(channels[-1],2),
            **common
        )

        # Per-level processing applied to the skip tensors returned by the encoder.
        self.middle = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(c,[c*2,c],**common),
                EfficientSpatialChannelAttention(c),
                # get_lsa(c,1) if c>=256 else nn.Identity(),
                ResidualBlock(c,[c*2,c],**common),
                EfficientSpatialChannelAttention(c),
            )
            for c in level_channels
        ])

        # Use the instance's own class count; relying on a module-level
        # ``num_classes`` variable makes the class depend on notebook state.
        self.class_to_emb = nn.ModuleList(
            [nn.Linear(self.num_classes,c) for c in level_channels]
        )
        self.time_emb = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(1,c),
                    nn.GELU(),
                    nn.Linear(c,c),
                )
                for c in level_channels
            ]
        )
        self.combine = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(c*2,c),
                    nn.LayerNorm(c)
                )
                for c in level_channels
            ]
        )

    def forward(self,x,time,cls):
        """Run the conditioned U-Net.

        Args:
            x: Input tensor; its first dimension is the batch size.
            time: Time values, either a 0-dim tensor (broadcast to the whole
                batch), shape ``(batch,)`` or shape ``(batch, 1)``.
            cls: Either an integer tensor of class indices with shape
                ``(batch,)``, or a non-integer tensor that is used as the class
                vector as-is (it must then be ``num_classes`` wide). Integer
                indices outside ``[0, num_classes)`` (e.g. ``-1``) mean
                "no label" and become an all-zero class vector.

        Returns:
            The output of ``self.unet.decode`` (see ``unet1``).

        The caller's ``cls`` tensor is never modified.
        """
        if cls.dtype in self._INDEX_DTYPES:
            # one_hot needs int64 indices on the target device; .to() and
            # masked_fill are out-of-place, so the caller's labels stay intact.
            labels = cls.to(device=x.device, dtype=torch.long)

            # handle empty labels
            no_context = (labels<0) | (labels>=self.num_classes)
            # Temporarily map invalid indices to a valid one so one_hot accepts them.
            labels = labels.masked_fill(no_context, 0)
            cls = torch.nn.functional.one_hot(labels,self.num_classes).float()

            # if no label is provided, set the whole class vector of that element to 0
            cls = cls.masked_fill(no_context[..., None], 0.0)

        if time.dim()==0:
            # A single scalar time applies to every element of the batch.
            time = time.expand(x.shape[0])
        if time.dim()<2:
            time = time[:,None]

        return self.unet1(x, time, cls)

    def unet1(self, x, time, cls):
        """Build per-level contexts and run encoder, middle blocks and decoder.

        Args:
            x: Input tensor forwarded to ``self.unet.encode_with_context``.
            time: Tensor of shape ``(batch, 1)``.
            cls: Float tensor of shape ``(batch, num_classes)``.

        Returns:
            The result of ``self.unet.decode`` on the processed skip tensors.
        """
        context = []
        for class_emb,time_emb,c in zip(self.class_to_emb,self.time_emb,self.combine):
            cls_emb = class_emb(cls)
            t_emb = time_emb(time)

            # Fuse class and time embeddings, then add two trailing singleton
            # dimensions so the (batch, width) vector broadcasts over feature maps.
            context_ = c(torch.concat([cls_emb,t_emb],-1))[:,:,None,None]

            context.append(context_)

        skip = self.unet.encode_with_context(x,context)

        # Refine every skip tensor and re-inject the conditioning for its level.
        for i,m in enumerate(self.middle):
            skip[i]=m(skip[i])+context[i]

        rec = self.unet.decode(skip)
        return rec

# Number of classes found on disk; must be defined before the model is built.
num_classes = len(dataset.classes)
model = FlowMatchingModel(in_channels=4, classes=dataset.classes)

# Smoke test on random data: a batch of four 4-channel 128x128 inputs with one
# time value and one label each. With the three classes listed above, labels 3
# and -1 are out of range and exercise the "no label" path.
im = torch.randn(4, 4, 128, 128)
time = torch.tensor([0.1, 0.9, 0.2, 0.3])
classes = torch.tensor([1, 2, 3, -1], dtype=torch.long)
model(im, time, classes).shape

torch.Size([4, 4, 128, 128])

In [None]:
from kemsekov_torch.train import train
from kemsekov_torch.metrics import r2_score
from kemsekov_torch.flow_matching import FlowMatching

fm = FlowMatching()
loss = nn.MSELoss()

# Weight of the term subtracted from the regression loss below.
contrastive_lambda = 0.05
def compute_loss_and_metric(model,batch):
    """Compute the training loss and the r2 metric for one batch.

    Args:
        model: Callable invoked as ``model(x, t, cls)``.
        batch: ``(image, cls)`` pair; ``cls`` is indexable along its first
            dimension and holds one label per sample.

    Returns:
        tuple: ``(loss, metrics)`` where ``loss`` is
        ``mse(pred, target) - contrastive_lambda * mse(pred, contrastive_dir)``
        and ``metrics`` is ``{'r2': r2_score(pred, target)}``.

    The labels of a random quarter of the batch (``len(cls)//4`` samples) are
    replaced by ``-1``, i.e. "no label". The caller's batch is not modified.
    """
    image,cls = batch

    # Work on a private copy: writing -1 into the batch's own tensor would
    # leak the label dropout to anything else that reads this batch.
    cls = cls.clone()
    no_label_samples = torch.randperm(len(cls))[:len(cls)//4]
    cls[no_label_samples]=-1

    def run_model(x,t):
        # Hand every call its own copy so a model that edits its label input
        # in place cannot turn the -1 markers into a real class for later calls.
        return model(x,t,cls.clone())
    x0 = torch.randn_like(image)
    # Only pred, target and contrastive_dir are used; the returned t is ignored.
    pred,target,contrastive_dir,t = fm.contrastive_flow_matching_pair(run_model,x0,image)

    # Pull pred towards target while pushing it away from contrastive_dir.
    loss_ = loss(pred,target) - contrastive_lambda*loss(pred,contrastive_dir)
    return loss_,{
        'r2':r2_score(pred,target)
    }

# Training budget and optimisation setup.
epochs = 600
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
# One cosine period spans every batch of every epoch.
# NOTE(review): assumes train() steps the scheduler once per batch - confirm.
sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=len(train_loader)*epochs)

train(
    model,
    train_loader,
    test_loader,
    compute_loss_and_metric,
    "runs/vae-natural",
    "runs/vae-natural/last",
    num_epochs=epochs,
    optimizer=optim,
    scheduler=sch,
    gradient_clipping_max_norm=1,
    save_on_metric_improve=['r2'],
    accelerate_args={
        'mixed_precision':'bf16',
        # 'dynamo_backend':'inductor'
    },
    ema_args={
        'update_after_step':100,
        'beta':0.99,
        'power':1,
        'use_foreach':True
    },
)

Total model parameters 43.59 M
Using device cuda
Failed to load state with error [Errno 2] No such file or directory: 'runs/vae-natural/last/state/pytorch_model.bin'
Ignoring state loading...
trying to capture model architecture...
Saved model architecture at runs/vae-natural/model.pt. You can torch.load it and update it's weights with checkpoint

Epoch 1/600


train 0: 100%|██████████| 675/675 [00:46<00:00, 14.61it/s, loss=0.6680, r2=0.5119]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.72793 | 0.64002 |
|  r2  | 0.4672  | 0.5198  |
+------+---------+---------+
saved epoch-1

Epoch 2/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 18.04it/s, loss=0.6512, r2=0.5330]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.62912 | 0.60496 |
|  r2  | 0.5279  | 0.5416  |
+------+---------+---------+
saved epoch-2

Epoch 3/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.98it/s, loss=0.5740, r2=0.5695]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.60105 | 0.57403 |
|  r2  | 0.5454  | 0.5612  |
+------+---------+---------+
saved epoch-3

Epoch 4/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 18.13it/s, loss=0.6246, r2=0.5413]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.58588 | 0.56786 |
|  r2  | 0.5546  | 0.5643  |
+------+---------+---------+
saved epoch-4

Epoch 5/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 18.12it/s, loss=0.6721, r2=0.5185]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.57613 | 0.56664 |
|  r2  | 0.5605  | 0.5655  |
+------+---------+---------+
saved epoch-5

Epoch 6/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.92it/s, loss=0.5786, r2=0.5728]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.56871 | 0.54931 |
|  r2  | 0.5652  | 0.5755  |
+------+---------+---------+
saved epoch-6

Epoch 7/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.99it/s, loss=0.7153, r2=0.4828]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.56442 | 0.55203 |
|  r2  | 0.5678  | 0.5734  |
+------+---------+---------+

Epoch 8/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 18.12it/s, loss=0.5492, r2=0.5887]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.55705 | 0.56014 |
|  r2  | 0.5723  | 0.5703  |
+------+---------+---------+

Epoch 9/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.91it/s, loss=0.6206, r2=0.5496]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.55502 | 0.55964 |
|  r2  | 0.5736  | 0.5695  |
+------+---------+---------+

Epoch 10/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.96it/s, loss=0.6279, r2=0.5394]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.55343 | 0.55679 |
|  r2  | 0.5744  | 0.5711  |
+------+---------+---------+

Epoch 11/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.79it/s, loss=0.5852, r2=0.5669]


+------+--------+---------+
|      | Train  |  Test   |
+------+--------+---------+
| loss | 0.5506 | 0.55243 |
|  r2  | 0.5762 | 0.5745  |
+------+--------+---------+

Epoch 12/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.89it/s, loss=0.6063, r2=0.5628]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.54835 | 0.53775 |
|  r2  | 0.5777  | 0.5834  |
+------+---------+---------+
saved epoch-12

Epoch 13/600


train 0: 100%|██████████| 675/675 [00:39<00:00, 16.94it/s, loss=0.5117, r2=0.6096]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.54761 | 0.54277 |
|  r2  | 0.5781  | 0.5813  |
+------+---------+---------+

Epoch 14/600


train 0: 100%|██████████| 675/675 [00:40<00:00, 16.52it/s, loss=0.5846, r2=0.5668]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.54696 | 0.54162 |
|  r2  | 0.5787  | 0.5806  |
+------+---------+---------+

Epoch 15/600


train 0: 100%|██████████| 675/675 [00:44<00:00, 15.25it/s, loss=0.5353, r2=0.5933]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.54579 | 0.54103 |
|  r2  | 0.5793  | 0.5812  |
+------+---------+---------+

Epoch 16/600


train 0: 100%|██████████| 675/675 [00:42<00:00, 15.78it/s, loss=0.5911, r2=0.5653]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.54278 | 0.52925 |
|  r2  | 0.5811  | 0.5874  |
+------+---------+---------+
saved epoch-16

Epoch 17/600


train 0: 100%|██████████| 675/675 [00:43<00:00, 15.44it/s, loss=0.5563, r2=0.5875]


+------+--------+---------+
|      | Train  |  Test   |
+------+--------+---------+
| loss | 0.5397 | 0.52694 |
|  r2  | 0.5830 | 0.5899  |
+------+--------+---------+
saved epoch-17

Epoch 18/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 18.20it/s, loss=0.5438, r2=0.5897]


+------+--------+---------+
|      | Train  |  Test   |
+------+--------+---------+
| loss | 0.5368 | 0.53672 |
|  r2  | 0.5847 | 0.5825  |
+------+--------+---------+

Epoch 19/600


train 0: 100%|██████████| 675/675 [00:39<00:00, 16.90it/s, loss=0.5053, r2=0.6159]


+------+---------+--------+
|      |  Train  |  Test  |
+------+---------+--------+
| loss | 0.53744 | 0.5239 |
|  r2  | 0.5844  | 0.5909 |
+------+---------+--------+
saved epoch-19

Epoch 20/600


train 0: 100%|██████████| 675/675 [00:36<00:00, 18.35it/s, loss=0.5367, r2=0.5959]


+------+--------+---------+
|      | Train  |  Test   |
+------+--------+---------+
| loss | 0.5352 | 0.53505 |
|  r2  | 0.5860 | 0.5854  |
+------+--------+---------+

Epoch 21/600


train 0: 100%|██████████| 675/675 [00:35<00:00, 18.80it/s, loss=0.5299, r2=0.5982]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.53762 | 0.52851 |
|  r2  | 0.5841  | 0.5880  |
+------+---------+---------+

Epoch 22/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 18.01it/s, loss=0.5356, r2=0.5943]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.53585 | 0.53145 |
|  r2  | 0.5853  | 0.5863  |
+------+---------+---------+

Epoch 23/600


train 0: 100%|██████████| 675/675 [00:35<00:00, 19.25it/s, loss=0.5861, r2=0.5641]


+------+--------+---------+
|      | Train  |  Test   |
+------+--------+---------+
| loss | 0.5334 | 0.51942 |
|  r2  | 0.5867 | 0.5947  |
+------+--------+---------+
saved epoch-23

Epoch 24/600


train 0: 100%|██████████| 675/675 [00:34<00:00, 19.55it/s, loss=0.5509, r2=0.5904]


+------+---------+--------+
|      |  Train  |  Test  |
+------+---------+--------+
| loss | 0.53425 | 0.5337 |
|  r2  | 0.5864  | 0.5864 |
+------+---------+--------+

Epoch 25/600


train 0: 100%|██████████| 675/675 [00:35<00:00, 18.95it/s, loss=0.5145, r2=0.6063]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.53169 | 0.52661 |
|  r2  | 0.5877  | 0.5894  |
+------+---------+---------+

Epoch 26/600


train 0: 100%|██████████| 675/675 [00:40<00:00, 16.61it/s, loss=0.5133, r2=0.6095]


+------+--------+---------+
|      | Train  |  Test   |
+------+--------+---------+
| loss | 0.529  | 0.52477 |
|  r2  | 0.5892 | 0.5908  |
+------+--------+---------+

Epoch 27/600


train 0: 100%|██████████| 675/675 [00:45<00:00, 14.96it/s, loss=0.5615, r2=0.5811]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.53035 | 0.52282 |
|  r2  | 0.5888  | 0.5933  |
+------+---------+---------+

Epoch 28/600


train 0: 100%|██████████| 675/675 [00:38<00:00, 17.52it/s, loss=0.5592, r2=0.5836]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52671 | 0.51489 |
|  r2  | 0.5910  | 0.5973  |
+------+---------+---------+
saved epoch-28

Epoch 29/600


train 0: 100%|██████████| 675/675 [00:35<00:00, 18.92it/s, loss=0.4833, r2=0.6331]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52901 | 0.52668 |
|  r2  | 0.5896  | 0.5920  |
+------+---------+---------+

Epoch 30/600


train 0: 100%|██████████| 675/675 [00:40<00:00, 16.79it/s, loss=0.5283, r2=0.5999]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52797 | 0.51151 |
|  r2  | 0.5901  | 0.5996  |
+------+---------+---------+
saved epoch-30

Epoch 31/600


train 0: 100%|██████████| 675/675 [00:42<00:00, 15.91it/s, loss=0.5576, r2=0.5850]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52789 | 0.52717 |
|  r2  | 0.5902  | 0.5904  |
+------+---------+---------+

Epoch 32/600


train 0: 100%|██████████| 675/675 [00:34<00:00, 19.39it/s, loss=0.6393, r2=0.5389]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52469 | 0.51282 |
|  r2  | 0.5921  | 0.5980  |
+------+---------+---------+

Epoch 33/600


train 0: 100%|██████████| 675/675 [00:34<00:00, 19.40it/s, loss=0.5543, r2=0.5876]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52577 | 0.51775 |
|  r2  | 0.5915  | 0.5959  |
+------+---------+---------+

Epoch 34/600


train 0: 100%|██████████| 675/675 [00:34<00:00, 19.36it/s, loss=0.5985, r2=0.5703]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52751 | 0.51553 |
|  r2  | 0.5910  | 0.5968  |
+------+---------+---------+

Epoch 35/600


train 0: 100%|██████████| 675/675 [00:43<00:00, 15.38it/s, loss=0.5182, r2=0.6098]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52375 | 0.52652 |
|  r2  | 0.5931  | 0.5905  |
+------+---------+---------+

Epoch 36/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 18.18it/s, loss=0.5989, r2=0.5605]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52233 | 0.51533 |
|  r2  | 0.5933  | 0.5975  |
+------+---------+---------+

Epoch 37/600


train 0: 100%|██████████| 675/675 [00:37<00:00, 17.85it/s, loss=0.5252, r2=0.6075]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52472 | 0.52252 |
|  r2  | 0.5924  | 0.5923  |
+------+---------+---------+

Epoch 38/600


train 0: 100%|██████████| 675/675 [00:35<00:00, 18.76it/s, loss=0.5691, r2=0.5772]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52505 | 0.51243 |
|  r2  | 0.5922  | 0.5989  |
+------+---------+---------+

Epoch 39/600


train 0: 100%|██████████| 675/675 [00:38<00:00, 17.33it/s, loss=0.5227, r2=0.6030]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52271 | 0.52413 |
|  r2  | 0.5932  | 0.5920  |
+------+---------+---------+

Epoch 40/600


train 0: 100%|██████████| 675/675 [00:41<00:00, 16.19it/s, loss=0.5180, r2=0.6052]


+------+---------+---------+
|      |  Train  |  Test   |
+------+---------+---------+
| loss | 0.52312 | 0.51595 |
|  r2  | 0.5933  | 0.5971  |
+------+---------+---------+

Epoch 41/600


train 0: 100%|██████████| 675/675 [00:39<00:00, 17.05it/s, loss=0.5530, r2=0.5843]
