In [1]:
import os, sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image

In [38]:
dir_host_imgs = "../data/ZTFBTS/hostImgs/"
host_imgs = []

# os.listdir returns entries in arbitrary order; sorting makes the row order of
# `host_imgs` reproducible across runs and machines.
for filename in sorted(os.listdir(dir_host_imgs)):
    # Only .png files are loaded; anything else in the folder is skipped
    if not filename.endswith(".png"):
        continue
    file_path = os.path.join(dir_host_imgs, filename)
    # The context manager closes the file handle once the pixels are copied out;
    # convert('RGB') forces 3 channels so every image stacks to (H, W, 3).
    with Image.open(file_path) as img:
        host_imgs.append(np.asarray(img.convert('RGB')))

# Stack into one (N, H, W, 3) array (requires all images to share a size)
host_imgs = np.array(host_imgs)
print(host_imgs.shape)

(5170, 60, 60, 3)


In [6]:
dir_light_curves = "../data/ZTFBTS/light-curves/"

def open_light_curve_csv(filename):
    """Read a single light-curve CSV into a DataFrame.

    :param filename: Name of the CSV file; it is joined onto ``dir_light_curves``.
    :return: The ``pandas.DataFrame`` produced by ``pd.read_csv`` for that file.
    """
    return pd.read_csv(os.path.join(dir_light_curves, filename))

# Load one example light curve and preview its first rows to check the column layout
light_curve_df = open_light_curve_csv("ZTF18aailmnv.csv")
light_curve_df.head()

Unnamed: 0,time,mag,magerr,band
0,58312.219097,20.132299,0.25236,R
1,58319.205984,18.713728,0.104188,g
2,58319.224942,18.808235,0.09266,R
3,58320.174525,18.467438,0.09392,g
4,58324.179444,18.514769,0.117073,R


In [32]:
from tqdm import tqdm

# Sorted so that the row order of the output arrays is reproducible
# (os.listdir returns entries in arbitrary order).
lightcurve_files = sorted(os.listdir(dir_light_curves))
band = 'R'
n_max_obs = 100
required_cols = ['time', 'mag', 'magerr', 'band']

# Seeded generator so the random subsampling below is reproducible
rng = np.random.default_rng(0)

# Every light curve is turned into fixed-length (n_max_obs,) arrays:
#   * more than n_max_obs observations -> draw n_max_obs of them at random, using the
#     same indices for time, mag and magerr;
#   * otherwise -> right-pad with zeros.
# `mask` is True where an entry is a real observation and False where it is padding.
mask_list = []
mag_list = []
magerr_list = []
time_list = []

for filename in tqdm(lightcurve_files):
    if not filename.endswith(".csv"):
        continue

    light_curve_df = open_light_curve_csv(filename)

    # Make sure the csv contains 'time', 'mag', 'magerr', and 'band' columns.
    # NOTE(review): skipped files produce no row, so row i here only lines up with
    # row i of another per-object array if that array skips the same objects - confirm.
    if not all(col in light_curve_df.columns for col in required_cols):
        continue

    # Keep only the observations taken in the selected band
    df_band = light_curve_df[light_curve_df['band'] == band]
    n_obs = len(df_band)

    if n_obs > n_max_obs:
        mask = np.ones(n_max_obs, dtype=bool)
        # Sample without replacement so no observation is duplicated, then sort the
        # indices so the selected observations keep their original file order.
        indices = np.sort(rng.choice(n_obs, n_max_obs, replace=False))
        time = df_band['time'].values[indices]
        mag = df_band['mag'].values[indices]
        magerr = df_band['magerr'].values[indices]
    else:
        mask = np.zeros(n_max_obs, dtype=bool)
        mask[:n_obs] = True

        # Pad each array with zeros, keeping every column in its own variable
        pad_width = (0, n_max_obs - n_obs)
        time = np.pad(df_band['time'].values, pad_width, 'constant')
        mag = np.pad(df_band['mag'].values, pad_width, 'constant')
        magerr = np.pad(df_band['magerr'].values, pad_width, 'constant')

    mask_list.append(mask)
    time_list.append(time)
    mag_list.append(mag)
    magerr_list.append(magerr)

# Stack the per-object arrays into (n_objects, n_max_obs) arrays
time_ary = np.array(time_list)
mag_ary = np.array(mag_list)
magerr_ary = np.array(magerr_list)
mask_ary = np.array(mask_list)
        

100%|██████████| 5170/5170 [00:03<00:00, 1627.92it/s]


In [39]:
import torch
import torch.nn as nn

In [44]:
class Residual(nn.Module):
    """Skip connection: adds the input back onto the output of the wrapped callable."""

    def __init__(self, fn):
        """
        :param fn: Callable (typically an ``nn.Module``) whose output can be added to its input.
        """
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x):
        """Return ``fn(x) + x``."""
        transformed = self.fn(x)
        return transformed + x
    

class ConvMixer(nn.Module):
    """
    ConvMixer-style image encoder: a strided patch-embedding convolution, ``depth``
    blocks of (residual depthwise conv -> pointwise conv), global average pooling
    and a three-layer MLP head.
    """

    def __init__(self, dim, depth, channels=1, kernel_size=5, patch_size=8, n_out=2):
        """
        :param dim: Number of feature channels used throughout the convolutional trunk.
        :param depth: Number of mixer blocks.
        :param channels: Number of channels of the input images.
        :param kernel_size: Kernel size of the depthwise convolutions.
        :param patch_size: Kernel size and stride of the patch-embedding convolution.
        :param n_out: Size of the output vector.
        """
        super().__init__()

        # Patch embedding: non-overlapping patch_size x patch_size patches -> dim channels
        layers = [
            nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size, bias=False),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ]

        for _ in range(depth):
            # Depthwise conv (groups=dim) mixes spatial positions within each channel
            spatial_mixing = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )
            # The 1x1 conv then mixes information across channels
            layers.append(
                nn.Sequential(
                    Residual(spatial_mixing),
                    nn.Conv2d(dim, dim, kernel_size=1),
                    nn.GELU(),
                    nn.BatchNorm2d(dim),
                )
            )

        # Collapse the spatial grid to one vector per image, then apply the MLP head
        layers += [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(dim, 2048),
            nn.GELU(),
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Linear(1024, n_out),
        ]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the whole stack on ``x`` (channels-first image batch) and return (batch, n_out)."""
        return self.net(x)

In [None]:
# torch.as_tensor accepts both the NumPy array and an already-converted tensor,
# so this cell can be re-run safely (torch.from_numpy raises on a tensor).
host_imgs = torch.as_tensor(host_imgs).float()

In [49]:
from einops import rearrange

# Conv2d expects channels-first input. The images were loaded as RGB, so a trailing
# axis of size 3 means the batch is still channels-last; the guard keeps a re-run of
# this cell from permuting an already channels-first tensor a second time.
if host_imgs.shape[-1] == 3:
    host_imgs = rearrange(host_imgs, 'b h w c -> b c h w')

In [70]:
# Image encoder for the 3-channel host images, producing a 128-dimensional output per image
convmixer = ConvMixer(dim=60, depth=8, channels=3, kernel_size=5, patch_size=8, n_out=128)
# Smoke test on the first two images: the result should have shape (2, 128)
convmixer(host_imgs[:2]).shape

torch.Size([2, 128])

In [71]:
import math

import sys
sys.path.append('../')
from models.transformer_utils import Transformer

class TimePositionalEncoding(nn.Module):
    """Sinusoidal positional encoding evaluated at the given time values instead of integer positions."""

    def __init__(self, d_emb):
        """
        Inputs
            d_emb - Size of the encoding produced for every time value.
        """
        super().__init__()
        self.d_emb = d_emb

    def forward(self, t):
        """
        Inputs
            t - Tensor of time values, shape (B, T).
        Returns
            Tensor of shape (B, T, d_emb): sines in the even slots, cosines in the odd slots.
        """
        n_batch, n_steps = t.shape[0], t.shape[1]

        # Frequencies 10000^(-k / d_emb) for k = 0, 2, 4, ...
        exponents = torch.arange(0, self.d_emb, 2).float() * (-math.log(10000.0) / self.d_emb)
        freqs = torch.exp(exponents)[None, None, :].to(t.device)  # (1, 1, D / 2)

        # (B, T, 1) * (1, 1, D / 2) broadcasts to (B, T, D / 2)
        angles = t.unsqueeze(2) * freqs

        encoding = torch.zeros(n_batch, n_steps, self.d_emb).to(t.device)  # (B, T, D)
        encoding[:, :, 0::2] = torch.sin(angles)
        encoding[:, :, 1::2] = torch.cos(angles)
        return encoding

class TransformerWithTimeEmbeddings(nn.Module):
    """
    Sequence encoder for time-stamped scalar observations.

    Each observation is linearly embedded, a sinusoidal encoding of its (relative)
    time is added, the sequence is passed through a Transformer, projected, mean-pooled
    over the sequence axis and mapped to a 128-dimensional embedding.
    """

    def __init__(self, n_out=1, **kwargs):
        """
        :param n_out: Output width of the per-step projection applied after the transformer.
        :param kwargs: Forwarded unchanged to ``Transformer``. Must contain ``emb``
                       (the embedding dimension), which also sizes the value and time embeddings.
        """
        super().__init__()

        self.embedding_mag = nn.Linear(in_features=1, out_features=kwargs['emb'])
        self.embedding_t = TimePositionalEncoding(kwargs['emb'])
        self.transformer = Transformer(**kwargs)
        self.projection = nn.Linear(in_features=kwargs['emb'], out_features=n_out)
        # Projection head. It is created once here (not inside forward) so that its weights
        # are registered parameters: they are trained, saved in the state_dict, moved by
        # .to(device), and give the same output for the same input on every call.
        self.head = nn.Linear(in_features=n_out, out_features=128)

    def forward(self, x, t, mask=None):
        """
        :param x: Observed values, shape (B, T, 1) (the trailing axis feeds ``Linear(1, emb)``).
        :param t: Observation times, shape (B, T).
        :param mask: Optional mask, passed through to the transformer as its second argument.
        :return: Tensor of shape (B, 128), one embedding per sequence.
        """
        # Express times relative to the first entry of each sequence
        t = t - t[:, 0].unsqueeze(1)
        t_emb = self.embedding_t(t)
        x = self.embedding_mag(x) + t_emb
        x = self.transformer(x, mask)
        x = self.projection(x)

        # Pool along sequence dim
        # NOTE(review): this averages over all T positions and does not use `mask`.
        x = x.mean(dim=1)

        # Projection head
        x = self.head(x)

        return x

# Light-curve encoder: emb=128 sets the embedding dimension, heads/depth are forwarded to
# Transformer, and n_out=64 is the width of the projection applied after the transformer.
transformer = TransformerWithTimeEmbeddings(n_out=64, emb=128, heads=2, depth=2)

In [67]:
# Light-curve arrays as tensors: float32 times and magnitudes, boolean mask
time = torch.as_tensor(time_ary, dtype=torch.float32)
mag = torch.as_tensor(mag_ary, dtype=torch.float32)
mask = torch.as_tensor(mask_ary, dtype=torch.bool)

In [68]:
# Pass through transformer. forward(x, t, mask) embeds x with Linear(1, emb) and encodes t
# as time, so the magnitudes go in as x (with a trailing feature axis) and the times as t.
transformer(mag[:2][..., None], time[:2], mask[:2]).shape

torch.Size([2, 128])

In [77]:
import torch.nn.functional as F

def _compute_losses(image_embeddings, text_embeddings):

    log_softmax = nn.LogSoftmax(dim=1)
    temperature = 0.1

    logits = (text_embeddings @ image_embeddings.T) / temperature
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax(
        (images_similarity + texts_similarity) / 2 * temperature, dim=-1
    )
    images_loss = (-targets.T * log_softmax(logits.T)).sum(1)
    texts_loss = (-targets * log_softmax(logits)).sum(1)
    return (images_loss + texts_loss) / 2.0

# Smoke test of the loss on two paired samples. The transformer takes the magnitudes as x
# (with a trailing feature axis) and the times as t, matching forward(x, t, mask).
_compute_losses(transformer(mag[:2][..., None], time[:2], mask[:2]), convmixer(host_imgs[:2]))

tensor([0.7670, 0.7739], grad_fn=<DivBackward0>)