In [1]:
import os

import lance
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

# Model Architectures

I have chosen two fundamentally different model architectures to test my POC on: a variational autoencoder (VAE) and a contrastive learning-based framework (SimCLR).

There are some points to make for choosing a generative framework like a VAE:

## Variational Autoencoder

### Pros
- It allows me to make sliders on latent factors for **explainability** purposes
- The latent space is inherently smooth
- I could enforce certain latent factors to encode a specific characteristic by creating subbranches from the latent variables towards subgoals (e.g. size, circularity, class [conditional VAE]).

### Drawbacks

- Reconstructions will be blurred (averaged) and are not the best in quality
- Learning might prove tricky 

## Contrastive learning

- Might prove easier to train
- Embeddings are much more powerful

## A combination

In the ideal case, I would combine the strengths of both: a SimCLR encoder for embedding generation with a variational decoder on top for interpretation.

In [47]:
# Dataset root. It can be overridden through the CELLENONE_DATA_DIR environment
# variable so the notebook also runs on other machines; the default is the
# original local path.
root_data = os.environ.get(
    'CELLENONE_DATA_DIR',
    '/home/sam/SCI/cellenONE_project/datasets'
)

lds = lance.dataset(
    os.path.join(root_data, 'test_Backup_SCI.lance')
)

# Load only the image column, then split the resulting table into a sequence of
# record batches that the following cells iterate over.
dtps = lds.to_table(
    columns=['cell_diff_crop'],
    batch_size=24
).to_batches()

In [79]:
from CellVision.data.preprocessing.augmentations import SimpleAugmentor
from PIL import Image

In [88]:
# Take the first batch explicitly. Unlike `for batch in dtps: break`, this raises
# StopIteration when there are no batches instead of silently leaving `batch`
# unbound (or holding a stale value from an earlier run of this cell).
batch = next(iter(dtps))

In [97]:
# `T` (torchvision.transforms) is imported in the import cell at the top of the
# notebook so that this cell also runs on a fresh kernel.
ImageToTensor = T.Compose([T.ToTensor()])

# Convert every raw crop of the batch into a grayscale ('L' mode) PIL image.
cells = [
    Image.fromarray(np.array(img, dtype=float)).convert('L')
    for img in batch["cell_diff_crop"].tolist()
]

# Two augmented views per cell (two separate SimpleAugmentor calls) plus the
# un-augmented image as reconstruction target.
aug_1 = torch.cat([SimpleAugmentor(x) for x in cells])
aug_2 = torch.cat([SimpleAugmentor(x) for x in cells])
targets = torch.cat([ImageToTensor(x) for x in cells])

# NOTE: the per-image tensors are concatenated along dim 0, so the shapes below
# have no separate channel axis. TODO: add a channel dimension (e.g. unsqueeze(1))
# before feeding these to the models, which expect (B, C, H, W).
{
    'targets': targets.shape,
    'aug_1': aug_1.shape,
    'aug_2': aug_2.shape
}

{'targets': torch.Size([24, 64, 64]),
 'aug_1': torch.Size([24, 64, 64]),
 'aug_2': torch.Size([24, 64, 64])}

In [None]:
import torchvision.transforms as T



# Inspect the tensor produced by ImageToTensor for the first cell image.
ImageToTensor(cells[0])

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [86]:
# Shape of the concatenated augmented views. The previously referenced `augs`
# is not defined in any cell of this notebook, so the already concatenated
# `aug_1` is inspected instead.
aug_1.shape

torch.Size([24, 64, 64])

In [73]:
# Shape of a single raw crop. The previously referenced `cell` is not defined in
# any cell of this notebook, so the first crop of `batch` is used instead.
torch.tensor(batch["cell_diff_crop"][:1].tolist()[0]).shape

torch.Size([64, 64])

## 0. Helper blocks

In [18]:
"""
Three backbones (PyTorch) for small grayscale images (25-45 px).
- BetaVAE: encoder -> latent (mu, logvar) -> decoder
- ContrastiveModel: encoder (GAP) -> projection head (for SimCLR/BYOL/SimSiam)
- Hybrid: contrastive encoder -> bottleneck VAE on top (decode from z)

Design goals / notes:
- Use Global Average Pooling (GAP) to reduce positional encoding
- Small capacity to match tiny images (~40x40)
- Inputs: 1-channel grayscale; adjust `in_ch` if needed
- Recommended use: resize/pad images to a stable size (e.g., 40x40 or pad to 56x56 and random crop)

"""
# ---------------------------
# Utility blocks
# ---------------------------
class ConvBlock(nn.Module):
    """Conv2d followed by an optional BatchNorm2d and a ReLU.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: padding passed to the convolution.
        use_bn: when True, a BatchNorm2d follows the convolution and the
            convolution is created without a bias term.
    """

    def __init__(self, in_c, out_c, kernel=3, stride=1, padding=1, use_bn=True):
        super().__init__()
        conv = nn.Conv2d(
            in_c,
            out_c,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            bias=not use_bn,  # the bias is dropped whenever batch norm follows
        )
        norm = [nn.BatchNorm2d(out_c)] if use_bn else []
        self.net = nn.Sequential(conv, *norm, nn.ReLU(inplace=True))

    def forward(self, x):
        """Apply convolution, optional batch norm and ReLU to `x`."""
        return self.net(x)


class GlobalAvgPool2d(nn.Module):
    """Average every channel over its spatial dimensions.

    For a (B, C, H, W) input the result is a (B, C) tensor.
    """

    def forward(self, x):
        pooled = F.adaptive_avg_pool2d(x, output_size=1)  # spatial dims collapse to 1x1
        return pooled.view(pooled.size(0), -1)


# ---------------------------
# Encoder (shared ideas)
# ---------------------------
class SmallEncoderGAP(nn.Module):
    """Compact convolutional encoder ending in global average pooling.

    Four ConvBlocks (the last three with stride 2) widen the channels from
    `base_filters` to `base_filters * 8`; the pooled channel vector is mapped to
    `out_feat` features by a Linear + ReLU layer.

    Args:
        in_ch: number of input image channels.
        base_filters: channel width of the first ConvBlock.
        out_feat: size of the returned feature vector.
    """

    def __init__(self, in_ch=1, base_filters=16, out_feat=32):
        super().__init__()
        widths = [base_filters * mult for mult in (1, 2, 4, 8)]
        # conv1 keeps the resolution; conv2-conv4 each downsample with stride 2.
        self.conv1 = ConvBlock(in_ch, widths[0], stride=1)
        self.conv2 = ConvBlock(widths[0], widths[1], stride=2)
        self.conv3 = ConvBlock(widths[1], widths[2], stride=2)
        self.conv4 = ConvBlock(widths[2], widths[3], stride=2)
        self.gap = GlobalAvgPool2d()
        self.fc = nn.Sequential(
            nn.Linear(widths[3], out_feat),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Encode a (B, C, H, W) batch into (B, out_feat) features."""
        feats = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats = stage(feats)
        pooled = self.gap(feats)  # (B, channels)
        return self.fc(pooled)


# ---------------------------
# Decoder for VAE (simple upsampling)
# ---------------------------
class SmallDecoder(nn.Module):
    """Decode a latent vector into an image of size (out_size, out_size).

    The latent is projected by a linear layer, reshaped into an 8x8 feature map
    and passed through three nearest-neighbour x2 upsampling stages, so the
    native output resolution is 64x64. A final sigmoid keeps values in [0, 1].

    Args:
        z_dim: dimensionality of the latent vector.
        out_ch: number of channels of the decoded image.
        base_filters: channel width of the last ConvBlock; earlier stages use
            multiples of it.
        out_size: requested spatial size of the output. When it differs from
            the native 64, the decoded image is resized bilinearly so the
            parameter is actually honoured; for 64 no resizing takes place.
    """

    # Spatial size of the feature map the latent is reshaped into.
    _seed_hw = 8

    def __init__(self, z_dim=16, out_ch=1, base_filters=16, out_size=64):
        super().__init__()
        self.out_size = out_size
        hidden_ch = base_filters * 8
        self.fc = nn.Linear(z_dim, hidden_ch * self._seed_hw * self._seed_hw)

        # 8x8 -> 16x16 -> 32x32 -> 64x64
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            ConvBlock(hidden_ch, base_filters * 4),
            nn.Upsample(scale_factor=2, mode='nearest'),
            ConvBlock(base_filters * 4, base_filters * 2),
            nn.Upsample(scale_factor=2, mode='nearest'),
            ConvBlock(base_filters * 2, base_filters),
            nn.Conv2d(base_filters, out_ch, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, z):
        """Map latents of shape (B, z_dim) to images (B, out_ch, out_size, out_size)."""
        B = z.size(0)
        x = self.fc(z)
        x = x.view(B, -1, self._seed_hw, self._seed_hw)  # (B, hidden_ch, 8, 8)
        x = self.up(x)
        # Only resize when the requested size differs from the native resolution,
        # so the default 64x64 path stays free of interpolation.
        if tuple(x.shape[-2:]) != (self.out_size, self.out_size):
            x = F.interpolate(
                x,
                size=(self.out_size, self.out_size),
                mode='bilinear',
                align_corners=False
            )
        return x

# ---------------------------
# Contrastive backbone + projection head
# ---------------------------
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) projection head.

    Maps encoder features into the projection space used by the contrastive
    loss (SimCLR/BYOL style).

    Args:
        in_dim: size of the incoming feature vector.
        hidden_dim: width of the hidden layer.
        out_dim: size of the projection.
    """

    def __init__(self, in_dim=128, hidden_dim=128, out_dim=64):
        super().__init__()
        hidden = nn.Linear(in_dim, hidden_dim)
        output = nn.Linear(hidden_dim, out_dim)
        self.net = nn.Sequential(hidden, nn.ReLU(inplace=True), output)

    def forward(self, x):
        """Project features `x` of shape (B, in_dim) to (B, out_dim)."""
        return self.net(x)

In [None]:


# `dtps` is a sequence of record batches and cannot be indexed by column name;
# take the first five crops from the already extracted `batch` instead.
img_b = batch['cell_diff_crop'][:5].tolist()



In [19]:
# Build an (N, 1, H, W) float tensor from the first five crops. `dtps` is a
# sequence of record batches and cannot be indexed by column name, so the crops
# come from the already extracted `batch`.
test = torch.tensor(batch['cell_diff_crop'][:5].tolist()).unsqueeze(dim=1).float()
test.shape

torch.Size([5, 1, 64, 64])

In [22]:
# Stand-alone encoder/decoder pair for a quick shape check: the encoder emits
# 16 features, which the decoder takes as its 16-dimensional latent input.
encoder = SmallEncoderGAP(in_ch=1, base_filters=16, out_feat=16)
decoder = SmallDecoder(z_dim=16, out_ch=1, base_filters=16, out_size=64)

- torch.Size([5, 1, 64, 64])
- Conv1: torch.Size([5, 16, 64, 64])
- Conv2: torch.Size([5, 32, 32, 32])
- Conv3: torch.Size([5, 64, 16, 16])
- Conv4: torch.Size([5, 128, 8, 8])
- GAP: torch.Size([5, 128])
- FC: torch.Size([5, 16])

In [26]:
# Round trip: input batch -> feature vector -> reconstructed image.
encoded = encoder(test)
decoded = decoder(encoded)

# Report the shape at every stage, in pipeline order.
for stage in (test, encoded, decoded):
    print(stage.shape)

torch.Size([5, 1, 64, 64])
torch.Size([5, 16])
torch.Size([5, 1, 64, 64])


# 1. beta Variational Autoencoder

In [None]:
class BetaVAE(nn.Module):
    """Small beta-VAE built on a SmallEncoderGAP trunk and a SmallDecoder.

    Args:
        in_ch: number of image channels (also used as decoder output channels).
        z_dim: dimensionality of the latent space.
        base_filters: base channel width shared by encoder and decoder.
        hidden_feat: size of the trunk feature vector feeding mu / logvar.
        out_size: forwarded to the decoder as its `out_size`.
        beta: stored on the module as `self.beta`; it is not used inside this
            class (presumably the KL weight for the training loss — confirm).
    """

    def __init__(self, in_ch=1, z_dim=8, base_filters=16, hidden_feat=16, out_size=64, beta=1.0):
        super().__init__()
        self.beta = beta
        self.encoder_trunk = SmallEncoderGAP(
            in_ch=in_ch,
            base_filters=base_filters,
            out_feat=hidden_feat,
        )
        # Two linear heads parameterise the approximate posterior.
        self.fc_mu = nn.Linear(hidden_feat, z_dim)
        self.fc_logvar = nn.Linear(hidden_feat, z_dim)
        self.decoder = SmallDecoder(
            z_dim=z_dim,
            out_ch=in_ch,
            base_filters=base_filters,
            out_size=out_size,
        )

    def encode(self, x):
        """Return the posterior parameters (mu, logvar) for images `x`."""
        feats = self.encoder_trunk(x)
        return self.fc_mu(feats), self.fc_logvar(feats)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps drawn from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Decode latent vectors `z` into images."""
        return self.decoder(z)

    def forward(self, x):
        """Return the 4-tuple (recon, z, mu, logvar) for images `x`."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), z, mu, logvar

## Test

In [29]:
vae = BetaVAE(
    in_ch=1,
    z_dim=8,
    base_filters=16,
    hidden_feat=16,
    out_size=64,
    beta=1
)

print(test.shape)
# BetaVAE.forward returns four values (recon, z, mu, logvar); unpacking them
# into three names raises a ValueError.
reconstructed, z, mu, logvar = vae(test)
print(reconstructed.shape, mu.shape, logvar.shape)

torch.Size([5, 1, 64, 64])
torch.Size([5, 1, 64, 64]) torch.Size([5, 8]) torch.Size([5, 8])


# 2. Contrastive Learning Model (simCLR)

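TODO: adapt the down- and upsampling to my 68x68 image dimensions so that the interpolation step can be removed as well. (Note: the tensors used in this notebook are 64x64; confirm the target size.)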

Add the following as hyperparameters for VAE:
- channels (filters)
- layers
- latent space dimension

In [None]:
class ContrastiveModel(nn.Module):
    """SmallEncoderGAP trunk followed by a ProjectionHead.

    Intended for SimCLR/BYOL-style training loops.

    Args:
        in_ch: number of image channels.
        base_filters: base channel width of the encoder trunk.
        feat_dim: size of the trunk feature vector.
        proj_hidden: hidden width of the projection head.
        proj_dim: size of the projection.
    """

    def __init__(self, in_ch=1, base_filters=16, feat_dim=128, proj_hidden=128, proj_dim=64):
        super().__init__()
        self.encoder_trunk = SmallEncoderGAP(
            in_ch=in_ch,
            base_filters=base_filters,
            out_feat=feat_dim,
        )
        self.proj = ProjectionHead(
            in_dim=feat_dim,
            hidden_dim=proj_hidden,
            out_dim=proj_dim,
        )

    def forward(self, x):
        """Return (feats, z): trunk features and the normalised projection."""
        feats = self.encoder_trunk(x)
        # The projection is always L2-normalised along dim 1 (as used by NT-Xent).
        z = F.normalize(self.proj(feats), dim=1)
        return feats, z

## Test

# 3. The Hybrid (skip for now)

In [None]:
# ---------------------------
# Hybrid: contrastive encoder + VAE on top
# ---------------------------
class ContrastiveVAE(nn.Module):
    """Hybrid model: contrastive encoder trunk -> VAE bottleneck -> decoder.

    Idea: train the trunk with a contrastive loss, then attach the mu/logvar
    heads to the same features and train the decoder (or fine-tune jointly).

    Args:
        in_ch: number of image channels (also used as decoder output channels).
        base_filters: base channel width of the encoder trunk.
        feat_dim: size of the trunk feature vector.
        z_dim: dimensionality of the latent space.
        decoder_base: base channel width of the decoder.
        out_size: forwarded to the decoder as its `out_size`.
        beta: stored on the module as `self.beta`; it is not used inside this
            class (presumably the KL weight for the training loss — confirm).
    """

    def __init__(self, in_ch=1, base_filters=16, feat_dim=128, z_dim=64, decoder_base=16, out_size=40, beta=1.0):
        super().__init__()
        self.beta = beta
        # Same trunk type as the contrastive backbone.
        self.encoder_trunk = SmallEncoderGAP(
            in_ch=in_ch,
            base_filters=base_filters,
            out_feat=feat_dim,
        )
        # Linear heads that parameterise the approximate posterior.
        self.fc_mu = nn.Linear(feat_dim, z_dim)
        self.fc_logvar = nn.Linear(feat_dim, z_dim)
        # The decoder reconstructs from the sampled latent.
        self.decoder = SmallDecoder(
            z_dim=z_dim,
            out_ch=in_ch,
            base_filters=decoder_base,
            out_size=out_size,
        )

    def encode_to_feats(self, x):
        """Return the trunk features for images `x`."""
        return self.encoder_trunk(x)

    def encode(self, x):
        """Return the posterior parameters (mu, logvar) for images `x`."""
        feats = self.encode_to_feats(x)
        return self.fc_mu(feats), self.fc_logvar(feats)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps drawn from a standard normal."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Decode latent vectors `z` into images."""
        return self.decoder(z)

    def forward(self, x):
        """Return the 4-tuple (feats, recon, mu, logvar) for images `x`."""
        feats = self.encode_to_feats(x)
        mu, logvar = self.fc_mu(feats), self.fc_logvar(feats)
        recon = self.decode(self.reparameterize(mu, logvar))
        return feats, recon, mu, logvar

## Test