# Self-supervised and multi-modal representation Learning: Notebook 3

Here we will align the image and light curve representations with contrastive learning. Optionally, we can use the light curve encoder we trained previously.

## Multi-modal contrastive learning with CLIP

### Light curve encoding via masked self-supervised learning

In [2]:
import os, sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from PIL import Image

import torch
from einops import rearrange

Load host images and inspect shape.

In [3]:
dir_host_imgs = "../data/ZTFBTS/hostImgs/"
host_imgs = []

for filename in os.listdir(dir_host_imgs):
    file_path = os.path.join(dir_host_imgs, filename)
    if file_path.endswith(".png"):
        host_img = Image.open(file_path).convert('RGB')
        host_img = np.asarray(host_img)
        host_imgs.append(host_img)

host_imgs = np.array(host_imgs)

host_imgs = torch.from_numpy(host_imgs).float()
host_imgs = rearrange(host_imgs, 'b h w c -> b c h w')

# Normalize
host_imgs /= 255.0

Load light curves and pre-process them just like in the previous notebook.

In [4]:
dir_light_curves = "../data/ZTFBTS/light-curves/"

def open_light_curve_csv(filename):
    file_path = os.path.join(dir_light_curves, filename)
    df = pd.read_csv(file_path)
    return df

light_curve_df = open_light_curve_csv("ZTF18aailmnv.csv")
light_curve_df.head()

Unnamed: 0,time,mag,magerr,band
0,58312.219097,20.132299,0.25236,R
1,58319.205984,18.713728,0.104188,g
2,58319.224942,18.808235,0.09266,R
3,58320.174525,18.467438,0.09392,g
4,58324.179444,18.514769,0.117073,R


In [5]:
band = 'R'
n_max_obs = 50

lightcurve_files = os.listdir(dir_light_curves)

# For entries with > n_max_obs observations, randomly sample n_max_obs observations (hmjd, mag, and magerr with same sample) from the light curve
# Pad the entries to n_max_obs observations with zeros and create a mask array
mask_list = []
mag_list = []
magerr_list = []
time_list = []

for filename in tqdm(lightcurve_files):
    if filename.endswith(".csv"):
        light_curve_df = open_light_curve_csv(filename)
        
        # Make sure the csv contains 'time', 'mag', 'magerr', and 'band' columns
        if not all(col in light_curve_df.columns for col in ['time', 'mag', 'magerr', 'band']):
            continue
        
        bands = light_curve_df['band'].unique()
        df_band = light_curve_df[light_curve_df['band'] == band]

        if len(df_band['mag']) > n_max_obs:
            mask = np.ones(n_max_obs, dtype=bool)
            mask_list.append(mask)
            indices = np.random.choice(len(df_band['mag']), n_max_obs)
            time = df_band['time'].values[indices]
            mag = df_band['mag'].values[indices]
            magerr = df_band['magerr'].values[indices]
        else:
            mask = np.zeros(n_max_obs, dtype=bool)
            mask[:len(df_band['mag'])] = True
            mask_list.append(mask)

            # Pad the arrays with zeros
            time = np.pad(df_band['mag'], (0, n_max_obs - len(df_band['mag'])), 'constant')
            mag = np.pad(df_band['magerr'], (0, n_max_obs - len(df_band['magerr'])), 'constant')
            magerr = np.pad(df_band['time'], (0, n_max_obs - len(df_band['time'])), 'constant')
            
        time_list.append(time)
        mag_list.append(mag)
        magerr_list.append(magerr)

time_ary = np.array(time_list)
mag_ary = np.array(mag_list)
magerr_ary = np.array(magerr_list)
mask_ary = np.array(mask_list)

100%|██████████| 5170/5170 [00:03<00:00, 1579.60it/s]


In [6]:
# Inspect shapes
time_ary.shape, mag_ary.shape, magerr_ary.shape, mask_ary.shape

((5170, 50), (5170, 50), (5170, 50), (5170, 50))

Looks good so far!

### Image encoder

In [7]:
import torch
import torch.nn as nn

In [8]:
class Residual(nn.Module):
    """
    A residual block that adds the input to the output of a function.
    """
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # Apply the function and add the input to the result
        return self.fn(x) + x
    

class ConvMixer(nn.Module):
    """
    ConvMixer model, a simple and efficient convolutional neural network.
    """
    def __init__(self, dim, depth, channels=1, kernel_size=5, patch_size=8, n_out=128):
        super(ConvMixer, self).__init__()

        # Initial convolution layer
        self.net = nn.Sequential(
            nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size, bias=False),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        )

        # Adding depth number of ConvMixer layers
        for _ in range(depth):
            self.net.append(nn.Sequential(
                Residual(nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                    nn.GELU(),
                    nn.BatchNorm2d(dim)
                )),
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim)
            ))

        # Projection head
        self.projection = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(dim, 1024),
            nn.GELU(),
            nn.Linear(1024, n_out)
        )

    def forward(self, x):
        # Forward pass through the network
        x = self.net(x)
        x = self.projection(x)
        return x


In [9]:
convmixer = ConvMixer(dim=60, depth=8, channels=3, kernel_size=5, patch_size=8, n_out=128)
convmixer(host_imgs[:4]).shape

torch.Size([4, 128])

In [10]:
import math

import sys
sys.path.append('../')
from models.transformer_utils import Transformer

class TimePositionalEncoding(nn.Module):

    def __init__(self, d_emb):
        """
        Inputs
            d_model - Hidden dimensionality.
        """
        super().__init__()
        self.d_emb = d_emb

    def forward(self, t):
        pe = torch.zeros(t.shape[0], t.shape[1], self.d_emb).to(t.device)  # (B, T, D)
        div_term = torch.exp(torch.arange(0, self.d_emb, 2).float() * (-math.log(10000.0) / self.d_emb))[None, None, :].to(t.device)  # (1, 1, D / 2)
        t = t.unsqueeze(2)  # (B, 1, T)
        pe[:, :, 0::2] = torch.sin(t * div_term)  # (B, T, D / 2)
        pe[:, :, 1::2] = torch.cos(t * div_term)  # (B, T, D / 2)
        return pe  # (B, T, D)

class TransformerWithTimeEmbeddings(nn.Module):
    """
    Transformer for classifying sequences
    """

    def __init__(self, n_out, **kwargs):
        """
        :param emb: Embedding dimension
        :param heads: nr. of attention heads
        :param depth: Number of transformer blocks
        :param seq_length: Expected maximum sequence length
        :param num_classes: Number of classes.
        :param max_pool: If true, use global max pooling in the last layer. If false, use global
                         average pooling.
        """
        super().__init__()
        
        self.embedding_mag = nn.Linear(in_features=1, out_features=kwargs['emb'])
        self.embedding_t = TimePositionalEncoding(kwargs['emb'])
        self.transformer = Transformer(**kwargs)

        self.projection = nn.Linear(kwargs['emb'], n_out)

    def forward(self, x, t, mask=None):
        """
        :param x: A batch by sequence length integer tensor of token indices.
        :return: predicted log-probability vectors for each token based on the preceding tokens.
        """
        t = t - t[:, 0].unsqueeze(1)
        t_emb = self.embedding_t(t)
        x = self.embedding_mag(x) + t_emb
        x = self.transformer(x, mask)  # (B, T, D)
        
        # Zero out the masked values
        x = x * mask[:, :, None]

        # Max pool
        x = x.max(dim=1)[0]
        
        x = self.projection(x)
        return x

transformer = TransformerWithTimeEmbeddings(n_out=128, emb=128, heads=1, depth=1)

In [11]:
# Time and mag tensors
time = torch.from_numpy(time_ary).float()
mag = torch.from_numpy(mag_ary).float()
mask = torch.from_numpy(mask_ary).bool()

In [12]:
# Pass a batch through
transformer(mag[:4][..., None], time[:4], mask[:4]).shape

torch.Size([4, 128])

### Contrastive-style losses

Bidirection (symmetric between modalities) InfoNCE loss to compute alignment between image and light curve representations.

In [13]:
import torch.nn.functional as F
import pytorch_lightning as pl

  Referenced from: <7702F607-92FA-3D67-9D09-0710D936B85A> /opt/homebrew/Caskroom/miniforge/base/envs/cfm/lib/python3.10/site-packages/torchvision/image.so
  warn(


In [14]:
from torch.utils.data import TensorDataset, DataLoader, random_split

val_fraction = 0.1
batch_size = 64
n_samples_val = int(val_fraction * mag.shape[0])

dataset = TensorDataset(host_imgs, mag, time, mask)

dataset_train, dataset_val = random_split(dataset, [mag.shape[0] - n_samples_val, n_samples_val])
train_loader = DataLoader(dataset_train, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True)
val_loader = DataLoader(dataset_val, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=False)

In [15]:
def clip_loss(image_embeddings, text_embeddings, temperature=1):
    
    log_softmax = nn.LogSoftmax(dim=1)

    logits = (text_embeddings @ image_embeddings.T) / temperature
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2 * temperature, dim=-1)
    images_loss = (-targets.T * log_softmax(logits.T)).sum(1)
    texts_loss = (-targets * log_softmax(logits)).sum(1)
    return (images_loss + texts_loss) / 2.0

def sigmoid_loss(image_embeds, text_embeds, logit_scale=1., logit_bias=2.73):

    bs = text_embeds.shape[0]
    
    labels = 2 * torch.eye(bs) - torch.ones((bs, bs))
    labels = labels.to(text_embeds.device)

    logits = text_embeds @ image_embeds.t() * logit_scale + logit_bias
    logits = logits.to(torch.float64)
    
    positive_loss = -torch.mean(torch.log(torch.sigmoid(labels * logits)))
    
    shifted_image_embeds = torch.roll(image_embeds, 1, dims=0)
    negative_logits = text_embeds @ shifted_image_embeds.t() * logit_scale + logit_bias    
    negative_loss = -torch.mean(torch.log(1 - torch.sigmoid(negative_logits)))
    
    loss = positive_loss + negative_loss
    
    return loss

In [22]:
class LightCurveImageCLIP(pl.LightningModule):
    def __init__(self, 
                 enc_dim=64,
                 temperature=10.,
                 transformer_kwargs={"n_out":128, "emb":128, "heads":2, "depth":2}, 
                 conv_kwargs = {'dim': 60, 'depth': 4, 'channels': 3, 'kernel_size': 5, 'patch_size': 8, 'n_out': 128}, 
                 optimizer_kwargs={}, lr=1e-4):
        super().__init__()

        self.lr = lr
        self.optimizer_kwargs = optimizer_kwargs
        self.enc_dim = enc_dim

        # Make temperature a learnable parameter
        self.logit_scale = nn.Parameter(torch.tensor(math.log(temperature)), requires_grad=True)
        self.logit_bias = nn.Parameter(torch.tensor(-10.), requires_grad=True)

        self.lightcurve_encoder = TransformerWithTimeEmbeddings(**transformer_kwargs)
        self.image_encoder = ConvMixer(**conv_kwargs)

        self.lightcurve_projection = nn.Linear(transformer_kwargs['n_out'], enc_dim)
        self.image_projection = nn.Linear(conv_kwargs['n_out'], enc_dim)

    def forward(self, x_img, x_lc, t_lc, mask_lc=None):
        
        # Light curve encoder
        x_lc = self.lightcurve_embeddings_with_projection(x_lc, t_lc, mask_lc)
        
        # Image encoder
        x_img = self.image_embeddings_with_projection(x_img)
    
        # Normalized embeddings
        return x_img, x_lc
    
    def image_embeddings_with_projection(self, x_img):
        x_img = self.image_encoder(x_img)
        x_img = self.image_projection(x_img)
        return x_img / x_img.norm(dim=-1, keepdim=True)
    
    def lightcurve_embeddings_with_projection(self, x_lc, t_lc, mask_lc=None):
        x_lc = x_lc[..., None]
        x_lc = self.lightcurve_encoder(x_lc, t_lc, mask_lc)
        x_lc = self.lightcurve_projection(x_lc)
        return x_lc / x_lc.norm(dim=-1, keepdim=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, **self.optimizer_kwargs)
        return {"optimizer": optimizer}
    
    def training_step(self, batch, batch_idx):
        x_img, x_lc, t_lc, mask_lc = batch
        x_img, x_lc = self(x_img, x_lc, t_lc, mask_lc)
        loss = sigmoid_loss(x_img, x_lc, self.logit_scale, self.logit_bias).mean()
        # loss = clip_loss(x_img, x_lc, self.logit_scale,).mean()
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x_img, x_lc, t_lc, mask_lc = batch
        x_img, x_lc = self(x_img, x_lc, t_lc, mask_lc)
        loss = sigmoid_loss(x_img, x_lc, self.logit_scale, self.logit_bias).mean()
        # loss = clip_loss(x_img, x_lc, self.logit_scale, ).mean()
        self.log("val_loss", loss, on_epoch=True)
        return loss

In [23]:
clip_model = LightCurveImageCLIP(temperature=10., lr=1e-4)

x_img, x_lc, t_lc, mask_lc = next(iter(train_loader))
x_img, x_lc = clip_model(x_img, x_lc, t_lc, mask_lc)

sigmoid_loss(x_img, x_lc, clip_model.logit_scale).mean()

tensor(5.1890, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [24]:
trainer = pl.Trainer(max_epochs=10, accelerator='cpu')
trainer.fit(model=clip_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                  | Type                          | Params
------------------------------------------------------------------------
0 | lightcurve_encoder    | TransformerWithTimeEmbeddings | 544 K 
1 | image_encoder         | ConvMixer                     | 227 K 
2 | lightcurve_projection | Linear                        | 8.3 K 
3 | image_projection      | Linear                        | 8.3 K 
------------------------------------------------------------------------
787 K     Trainable params
0         Non-trainable params
787 K     Total params
3.151     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]