In [1]:
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_absolute_error
from scipy.stats import spearmanr, pearsonr
from datetime import datetime
from tqdm import tqdm
import h5py

In [2]:
class TransformerDecoder(nn.Module):
    """Self-attention block with an 18-channel linear head and mean pooling.

    The module applies multi-head self-attention, a residual + LayerNorm,
    a position-wise feed-forward network, a second residual + LayerNorm,
    a linear projection from ``d_model`` to 18 channels, and finally an
    average over dim 1 of the projected tensor.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        heads: Number of attention heads; must divide ``d_model``.
        forward_expansion: Hidden width of the feed-forward network as a
            multiple of ``d_model``.
        dropout: Dropout probability used in attention, in the feed-forward
            network and after each LayerNorm.
        max_length: Accepted for interface compatibility; not used.
        batch_first: Forwarded to ``nn.MultiheadAttention``. With the default
            ``False`` the attention layer treats dim 0 of the input as the
            sequence axis and dim 1 as the batch axis, whereas the pooling
            step below always averages over dim 1. Pass ``True`` to make
            attention operate over dim 1 as well, i.e. inputs laid out as
            ``(batch, seq, d_model)``. The default is kept at ``False`` so
            that existing callers and checkpoints behave exactly as before;
            parameter names and shapes are identical in both modes.
    """

    def __init__(self, d_model, heads, forward_expansion, dropout, max_length, batch_first=False):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=heads,
            dropout=dropout,
            batch_first=batch_first,
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, forward_expansion * d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(forward_expansion * d_model, d_model)
        )

        self.dropout = nn.Dropout(dropout)

        # Additional linear layer for output transformation
        self.output_transform = nn.Linear(d_model, 18)

        # Adaptive pooling layer to handle sequence length
        self.sequence_pooling = nn.AdaptiveAvgPool1d(1)

    def forward(self, x, enc_out=None, src_mask=None, trg_mask=None):
        """Run the block on ``x`` and return the pooled 18-channel output.

        Args:
            x: 3-D tensor whose last dimension is ``d_model``.
            enc_out: Unused; kept for interface compatibility.
            src_mask: Unused; kept for interface compatibility.
            trg_mask: Optional mask passed to the attention layer as
                ``attn_mask``.

        Returns:
            Tensor of shape ``(x.shape[0], 1, 18)``: the projected features
            averaged over dim 1.
        """
        # Self-attention: query, key and value are all the input itself.
        attention_output, _ = self.attention(x, x, x, attn_mask=trg_mask)
        query = self.dropout(self.norm1(attention_output + x))

        out = self.feed_forward(query)
        out = self.dropout(self.norm2(out + query))

        out_transformed = self.output_transform(out)

        # AdaptiveAvgPool1d pools the last axis, so move dim 1 there, reduce
        # it to length 1, then restore the original axis order.
        out_pooled = self.sequence_pooling(out_transformed.transpose(1, 2)).transpose(1, 2)

        return out_pooled

In [6]:
# Rebuild the architecture before loading weights: load_state_dict() needs
# parameter shapes that match the ones stored in the checkpoint.
trained_basenji_transformer = TransformerDecoder(d_model=1536, heads=6, forward_expansion=2, dropout=0.2, max_length=896)
trained_filepath  = 'model_20231128_080512_7'
# map_location='cpu' lets the checkpoint load regardless of the device it was
# saved from; the model is never moved off the CPU in this notebook.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source (PyTorch versions that support it also accept weights_only=True).
trained_basenji_transformer.load_state_dict(torch.load(trained_filepath, map_location='cpu'))
# eval() disables dropout for inference and returns the module, so the cell
# still displays the model summary.
trained_basenji_transformer.eval()

TransformerDecoder(
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)
  )
  (norm1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
  (feed_forward): Sequential(
    (0): Linear(in_features=1536, out_features=3072, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=3072, out_features=1536, bias=True)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (output_transform): Linear(in_features=1536, out_features=18, bias=True)
  (sequence_pooling): AdaptiveAvgPool1d(output_size=1)
)

In [7]:
# Destination HDF5 file for the model predictions.
predictions_savepath = '/clusterfs/nilah/oberon/datasets/cs282a/inference/transformer_model_20231128_080512_7.h5'
# Source embeddings, opened read-only. The handle `f` is deliberately left
# open because `dset` is read lazily by the inference cell below.
f = h5py.File('/clusterfs/nilah/oberon/datasets/basenji/embeddings/embeddings.h5', mode='r')
dset = f['embeddings']

In [8]:
# Run the trained model over every embedding and stream the predictions to
# disk one record at a time, so the full result never has to sit in memory.
# NOTE(review): the model's nn.MultiheadAttention is constructed without
# batch_first=True, so it reads the (1, 896, 1536) input as a sequence of
# length 1 with 896 batch entries, while the pooling step averages over the
# 896 axis. Confirm this matches how the checkpoint was trained.
with h5py.File(predictions_savepath,'w') as savefile:
    # Keep the dataset handle instead of looking it up by name on every write.
    single_bin = savefile.create_dataset(
        'single_bin',
        shape=(len(dset),1,18),
        chunks=(1,1,18),
        compression='gzip',
        compression_opts=9,
    )
    # Inference only: skip autograd bookkeeping so no graph is built (and then
    # thrown away) for each forward pass.
    with torch.no_grad():
        for i in tqdm(range(len(dset))):
            # float32 conversion matches what the legacy torch.Tensor(...) did.
            inputs = torch.as_tensor(dset[i], dtype=torch.float32).reshape(1,896,1536)
            predictions = trained_basenji_transformer(inputs)
            single_bin[i,:,:] = predictions.numpy()

100%|██████████| 38171/38171 [1:04:37<00:00,  9.84it/s]
