In [None]:
from random import shuffle
import random
import numpy as np
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms.v2 as T
from kemsekov_torch.train import split_dataset
from torch.utils.data import Subset     

# Root folder of the labelled images; ImageFolder (below) treats each sub-folder as one class.
# NOTE(review): hard-coded absolute local path — consider making it configurable.
images_path = '/home/vlad/Documents/image-classification'

# (height, width) of the images produced by `tr`; RESIZE[0] also drives the
# shorter-edge resize in `tr` and the 80% square crop in `runtime_tr`.
RESIZE=(128,128)

def random_square_channel(x):
    """Augmentation: with probability 1/6, square one randomly chosen colour channel.

    Works on a single image ``(C, H, W)`` or a batch ``(B, C, H, W)``: the
    channel axis is addressed as dim -3 in both cases (plain ``x[ind]`` would
    select a whole *image* when given a batch). The input is never modified in
    place — an earlier crop in a pipeline may pass in a view of the caller's
    tensor — so a copy is returned whenever a channel is changed.

    Args:
        x: tensor whose channel axis is dim -3.

    Returns:
        ``x`` itself when nothing is squared, otherwise a new tensor with one
        channel squared (the same channel for every image in a batch).
    """
    if random.randint(0, 5) == 0:  # randint is inclusive on both ends -> 1 in 6
        ind = random.randint(0, 2)
        out = x.clone()
        out[..., ind, :, :] = x[..., ind, :, :] * x[..., ind, :, :]
        return out
    return x
# Augmentation applied on the fly (to batches in loss_and_metric and to single
# images in the preview grid): square crop of 80% of RESIZE[0], colour jitter,
# 2% grayscale, and a random horizontal flip.
# NOTE(review): on a batched tensor, v2 transforms draw one set of random
# parameters per call, so every image in a batch gets the same crop/jitter/flip.
runtime_tr = T.Compose([
    T.RandomCrop([int(RESIZE[0]*0.8)]*2),
    # T.Lambda(random_square_channel),
    T.ColorJitter(0.4,0.4,0.4),
    T.RandomGrayscale(0.02),
    # p=0.5 so the flip is actually random; with p=1.0 every image was flipped,
    # which added no variation (both triplet views always shared orientation).
    T.RandomHorizontalFlip(0.5),
])
# interpolation = T.InterpolationMode.NEAREST
# Load-time transform used by ImageFolder: PIL image -> float tensor, resize the
# shorter edge to RESIZE[0], random-crop to RESIZE, then force exactly 3
# channels (replicate a single channel, or keep only the first three).
tr = T.Compose([
    T.ToTensor(),
    T.Resize(RESIZE[0]),
    T.RandomCrop(RESIZE),
    T.Lambda(lambda x: x[[0]*3] if len(x)==1 else x[:3]),
])

# Labelled dataset: one class per sub-folder of images_path, every image passed through `tr`.
dataset = ImageFolder(images_path, transform=tr)

# Seed handed to split_dataset (presumably fixes the train/test partition).
# Global RNG seeding is currently commented out; uncomment to also make augmentation repeatable:
# torch.random.manual_seed(random_state)
# random.seed(random_state)
random_state = 123

# Train/test datasets plus their DataLoaders; test_size=0.05 presumably holds out 5%.
train_dataset, test_dataset, train_loader, test_loader = split_dataset(
    dataset,
    random_state=random_state,
    test_size=0.05,
    batch_size=32,
    num_workers=8,
    prefetch_factor=1,
    pin_memory=True,
)

# Sizes of both splits and the number of classes.
len(train_dataset), len(test_dataset), len(dataset.classes)

In [None]:
import matplotlib.pyplot as plt
from random import randint

# Preview 16 random samples after the runtime augmentation. Each title shows the
# class name and the shape of the tensor *before* runtime_tr was applied.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))

for i in range(4):
    for j in range(4):
        ax = axes[i][j]
        index = randint(0, len(dataset) - 1)
        sample = dataset[index]
        image, label = sample[0], sample[1]
        image = runtime_tr(image)
        ax.set_title(f"{dataset.classes[label]} {list(sample[0].shape)}")
        ax.imshow(T.ToPILImage()(image))
        ax.axis("off")  # hide ticks/frame for a clean grid

fig.tight_layout()
plt.show()

In [None]:
from kemsekov_torch.residual import ResidualBlock, Residual
import torch.nn as nn
from kemsekov_torch.attention import EfficientSpatialChannelAttention

class VAE(torch.nn.Module):
    """Convolutional encoder producing a Gaussian latent map, plus a linear classifier head.

    ``enc`` is a stack of six ResidualBlock + EfficientSpatialChannelAttention
    stages (each ResidualBlock configured with kernel_size=4, stride=2, group
    normalization) followed by a 1x1 conv with ``2 * latent_size`` output
    channels. ``encode`` splits those channels in half into a latent mean and
    log-variance map.

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of logits produced by ``classify``.
    """

    def __init__(self, in_channels,num_classes):
        super().__init__()
        intermediate_size = 64
        hidden_size = 256
        latent_size = 512

        # Conv settings shared by every ResidualBlock below.
        common = dict(
            kernel_size=4,
            stride=2,
            normalization='group'
        )
        repeats = 2
        self.enc = nn.Sequential(
            ResidualBlock(in_channels,[intermediate_size]*repeats,**common),
            EfficientSpatialChannelAttention(intermediate_size),

            ResidualBlock(intermediate_size,[intermediate_size*2]*repeats,**common),
            EfficientSpatialChannelAttention(intermediate_size*2),

            ResidualBlock(intermediate_size*2,[hidden_size]*repeats,**common),
            EfficientSpatialChannelAttention(hidden_size),

            ResidualBlock(hidden_size,[hidden_size]*repeats,**common),
            EfficientSpatialChannelAttention(hidden_size),

            ResidualBlock(hidden_size,[latent_size,latent_size],**common),
            EfficientSpatialChannelAttention(latent_size),

            ResidualBlock(latent_size,[latent_size,2*latent_size],**common),
            EfficientSpatialChannelAttention(2*latent_size),

            # 1x1 projection; first half of the channels = mean, second half = log-variance.
            nn.Conv2d(2*latent_size,2*latent_size,1)
        )
        # Linear head over a latent_size vector (this notebook feeds it spatially averaged latents).
        self.classify = nn.Linear(latent_size,num_classes)

    @torch.jit.export
    def encode(self,x):
        """Encode ``x`` and return ``(mean, logvar)``, each with half of ``enc``'s output channels."""
        mean,logvar = self.enc(x).chunk(2,1)
        return mean,logvar

    def sample(self,mu,logvar,std : float = 1.0):
        """Reparameterised draw ``mu + eps * sigma * std`` with ``eps ~ N(0, 1)``.

        ``logvar`` is the log-variance, so the standard deviation is
        ``sigma = exp(0.5 * logvar)`` (using ``exp(logvar)`` — the variance —
        over-scaled the noise). ``std`` is an extra multiplier on the noise:
        ``std=0`` returns ``mu``; larger values widen the draw.
        """
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(mu) * sigma * std

    def forward(self,x):
        """Return ``(sampled_latent, mu, logvar)`` for input ``x``."""
        mu,logvar = self.encode(x)
        return self.sample(mu,logvar),mu,logvar

model = VAE(in_channels=3, num_classes=len(dataset.classes))
# Smoke test: run one random 3x128x128 image through the model and show the shape of the returned mu.
model(torch.randn(1, 3, 128, 128))[1].shape

In [None]:
from kemsekov_torch.train import *
from kemsekov_torch.common_modules import kl_divergence
import torch.nn as nn
from torchmetrics import F1Score
import warnings
# NOTE(review): this suppresses *all* warnings for the rest of the session.
warnings.filterwarnings("ignore")

# Triplet loss on pooled latents: margin 1, Euclidean (p=2) distance,
# no anchor/positive swap, averaged over the batch.
triplet = torch.nn.TripletMarginLoss(margin=1, p=2, swap=False, reduction='mean')
# Cross-entropy for the linear classification head.
CE = torch.nn.CrossEntropyLoss()

# Weight of the KL term in the total loss.
beta = 1/2

# Multiclass F1 over all dataset classes; reported for sampled and mean latents.
f1 = F1Score(task='multiclass', num_classes=len(dataset.classes))

def loss_and_metric(model : nn.Module, batch):
    """Compute the training loss and logged metrics for one batch.

    Two independently augmented views of the batch are encoded. Their spatially
    averaged latent samples form the anchor/positive pair of a triplet loss,
    with a shuffled copy of the first view's latents as negatives. A KL term
    on both views' (unpooled) latent maps and a cross-entropy loss for the
    linear head (trained on detached latents) are added.

    Args:
        model: the VAE (uses ``forward`` and ``classify``).
        batch: ``(images, labels)`` as yielded by the data loader.

    Returns:
        ``(loss, metric)`` where ``metric`` maps names to logged values.
    """
    ims,label = batch[0],batch[1]

    # Two independent random augmentations of the same batch -> positive pair.
    ims_tr = runtime_tr(ims)
    ims = runtime_tr(ims)

    sample_latent1,mu1,logvar1 = model(ims)
    sample_latent2,mu2,logvar2 = model(ims_tr)

    # KL on the full latent maps. It must use mu1 *before* the spatial pooling
    # below; the pooled mean no longer matches logvar1's shape or the
    # [-1,-2,-3] reduction dims (and would be inconsistent with the mu2 term).
    kl = kl_divergence(mu1,logvar1,[-1,-2,-3])+kl_divergence(mu2,logvar2,[-1,-2,-3])

    # Global average pool over the two trailing (spatial) dims -> one vector per image.
    sample_latent1 = sample_latent1.mean([-1,-2])
    sample_latent2 = sample_latent2.mean([-1,-2])
    mu1_pooled = mu1.mean([-1,-2])

    # Linear probe: detach() keeps CE gradients out of the encoder.
    pred_label_sample = model.classify(sample_latent1.detach())
    with torch.no_grad():
        pred_label_mu = model.classify(mu1_pooled)

    # Negatives: the batch's own latents in random order.
    sample_latent_perm = sample_latent1[torch.randperm(len(sample_latent1))]

    triplet_loss = triplet(sample_latent1,sample_latent2,sample_latent_perm)
    loss = triplet_loss+beta*kl + CE(pred_label_sample,label)

    metric = {
        'kl':kl,
        'triplet_loss':triplet_loss,
        'f1_sample':f1(pred_label_sample.softmax(-1).cpu(),label.cpu()),
        'f1_mu':f1(pred_label_mu.softmax(-1).cpu(),label.cpu()),
    }

    return loss,metric

epochs = 200

# AdamW over all model parameters (lr 1e-3) with cosine annealing whose period is
# epochs * len(train_loader) scheduler steps.
# NOTE(review): that period assumes `train` steps the scheduler once per batch — confirm.
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
sh = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs * len(train_loader))

# Output directory for this run; f"{path}/last" is passed as train()'s 6th positional argument.
path = 'runs/image_emb'
_ = train(
    model, train_loader, test_loader, loss_and_metric,
    path, f"{path}/last",
    num_epochs=epochs,
    optimizer=optim,
    scheduler=sh,
    gradient_clipping_max_norm=1,
    save_on_metric_improve=['f1_mu'],
    # bf16 mixed precision and the inductor dynamo backend (accelerate options).
    accelerate_args=dict(mixed_precision='bf16', dynamo_backend='inductor'),
    # presumably an EMA of the weights — see kemsekov_torch.train
    ema_args=dict(beta=0.999, power=1),
)

In [None]:
import random
import PIL
import PIL.Image
import os

random.seed(None)  # reseed Python's RNG from OS randomness / time so each run picks a new image
pixel_art_d='/home/vlad/Documents/image-classification/sheep/'
im_path = random.choice(os.listdir(pixel_art_d))
im_path = os.path.join(pixel_art_d, im_path)
# im_path='/home/vlad/Downloads/cat_dog_2.png'
im = PIL.Image.open(im_path).convert("RGB")
im = T.ToTensor()(im)

# Ground-truth label from the image's parent folder (ImageFolder names classes
# after sub-folders). Do not rely on a `label` left over from an earlier cell:
# it belongs to a different image. None when the image is outside the class folders.
label = dataset.class_to_idx.get(os.path.basename(os.path.dirname(os.path.abspath(im_path))))

# id = random.randint(0,len(test_dataset)-1)
# im,label = test_dataset[id]

print(label)
path = 'runs/image_emb/'

# Load the saved TorchScript model, then restore weights via load_checkpoint
# (index -1 presumably selects the latest checkpoint).
m = torch.jit.load(os.path.join(path,"model.pt"),map_location='cpu')
m = load_checkpoint(m,path,-1).cpu().eval()

std = 2  # noise multiplier passed to sample()
with torch.no_grad():
    mu,logvar = m.encode(im[None,:])
    print("Class", dataset.classes[label] if label is not None else "unknown")
    sample = m.sample(mu,logvar,std)[0]

print("input",im.shape)
print("latent",mu[0].shape)
resize=T.Resize(512,interpolation=T.InterpolationMode.NEAREST_EXACT)

# NOTE(review): a 1x3 grid is created but only the input panel is drawn;
# `sample` is computed above but not visualised yet.
plt.figure(figsize=(15,15))
plt.subplot(1,3,1)
plt.imshow(T.ToPILImage()(resize(im)))
plt.axis('off')
plt.title("input")
plt.show()

In [None]:
from kemsekov_torch.rotary_emb import RotEmb
import torch

# Scratch check of RotEmb across train/eval modes and two input shapes.
r = RotEmb()
# Two 5-D inputs that differ only in dims 1 and 2 (32x16 vs 64x8).
# NOTE(review): which dims RotEmb treats as sequence axes is not visible here — confirm.
x = torch.randn((3,32,16,8,16))
y = torch.randn((3,64,8,8,16))

# Call once in train mode with x, then twice in eval mode with y; outputs are discarded.
r.train()
r(x)
r.eval()
r(y)
r(y)
# Displayed value; presumably the cached max 1-D sequence length after the calls above — verify in RotEmb.
r.max_seq_len1d

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model.
# The LLM is bound to `llm` (not `model`) so this cell does not overwrite the
# VAE `model` that the training cells use if cells are re-run out of order.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
llm = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", torch_dtype=torch.bfloat16)
# Prepare the input prompt
prompt = "What is the capital of France?"

# Tokenize the input and move it to the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)

# Generate a response. `temperature` is only honoured in sampling mode, so
# enable sampling explicitly instead of depending on the checkpoint's
# generation config.
with torch.no_grad():
    outputs = llm.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7)

# Decode and print the response (the decoded text includes the prompt).
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
