### Running Inference on Different Fine-Tuning Models

This notebook lets CS 182/282A project reviewers verify that each trained model loads and runs inference. To re-run the experiments and see how the models were trained, see the corresponding subdirectories (referenced by the checkpoint paths below).

In [1]:
import h5py
import torch
from torch import nn
from collections import OrderedDict

#### 0. Test Embeddings

In [2]:
test_path = './sample_data/test_chunk_X1.h5'
# The file is intentionally left open (no `with` block): later cells read
# samples from `dset` lazily, which requires the underlying handle.
f = h5py.File(test_path, 'r')
dset = f['embeddings']

#### 1. Linear Transformation

In [3]:
class LinearTransform(nn.Module):
    """Linear probe: a 1x1 convolution from 1536 channels to 18 outputs.

    Takes input of shape (B, 1536, 896) and returns non-negative predictions
    of shape (B, 18, 896). ``nn.Conv1d`` also accepts an unbatched
    (1536, L) input, which yields an (18, L) output.
    """

    def __init__(self):
        super().__init__()
        # kernel_size=1 makes the convolution a per-position linear map across channels.
        self.conv_layer = nn.Conv1d(in_channels=1536, out_channels=18, kernel_size=1)
        nn.init.kaiming_normal_(self.conv_layer.weight, nonlinearity='relu')
        nn.init.zeros_(self.conv_layer.bias)
        # Softplus keeps the predictions non-negative.
        self.activation = nn.Softplus()

    def forward(self, x):
        """Apply the 1x1 convolution followed by Softplus."""
        return self.activation(self.conv_layer(x))

In [4]:
# Rebuild the probe architecture, then load its trained weights onto the CPU.
trained_probe = LinearTransform()
trained_probe.load_state_dict(
    torch.load(
        '../cs282a_linear-probing/first_full_run.pth',
        map_location=torch.device('cpu'),
    )
)
trained_probe.eval()

LinearTransform(
  (conv_layer): Conv1d(1536, 18, kernel_size=(1,), stride=(1,))
  (activation): Softplus(beta=1, threshold=20)
)

In [5]:
# Inference only: disable autograd so no computation graph is built per sample.
# Only the last sample's predictions are kept; this cell just checks that
# inference runs end to end.
with torch.no_grad():
    for sample in dset:
        inputs = torch.as_tensor(sample, dtype=torch.float32)
        # Transpose so the 1536 embedding channels come first, as Conv1d expects.
        predictions = trained_probe(inputs.transpose(0, 1))

#### 2. 1D CNN + Perceptron

In [6]:
class MLPModel(nn.Module):
    """1x1 Conv1d (1536 -> 500) + GELU, flattened into one linear layer.

    The final layer expects 500 * 896 = 448000 flattened features, so inputs
    must be shaped (B, 1536, 896); the output is (B, 18).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Stage names become state_dict keys, so saved checkpoints depend on them.
        stages = [
            ('conv1x1', nn.Conv1d(1536, 500, 1)),
            ('gelu1', nn.GELU()),
            ('flatten', nn.Flatten()),
            ('fc1', nn.Linear(448000, 18)),
        ]
        self.layers = nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run ``x`` through every stage of ``self.layers`` in order."""
        return self.layers(x)

In [7]:
mlp_model = MLPModel()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine,
# the same way the linear-probe checkpoint is loaded.
mlp_model.load_state_dict(
    torch.load(
        '../cs282a_conv1d_perceptron/model_20231128_063541_2',
        map_location=torch.device('cpu'),
    )
)
mlp_model.eval()

MLPModel(
  (layers): Sequential(
    (conv1x1): Conv1d(1536, 500, kernel_size=(1,), stride=(1,))
    (gelu1): GELU(approximate='none')
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=448000, out_features=18, bias=True)
  )
)

In [8]:
# Inference only: disable autograd so no computation graph is built per sample.
# Only the last sample's predictions are kept; this cell just checks that
# inference runs end to end.
with torch.no_grad():
    for sample in dset:
        # Add a batch dimension, then put the 1536 channels before the 896 positions.
        inputs = torch.as_tensor(sample, dtype=torch.float32).reshape(1, 896, 1536)
        predictions = mlp_model(inputs.transpose(1, 2))

#### 3. 1D CNN + Max Pooling + Perceptron

In [9]:
class MLPModelPooling(nn.Module):
    """1x1 Conv1d (1536 -> 500) + GELU, max-pooled over 896 positions, then linear.

    The MaxPool1d window equals 896, so a (B, 1536, 896) input pools to
    (B, 500, 1) before flattening; the output is (B, 18).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Stage names become state_dict keys, so saved checkpoints depend on them.
        stages = [
            ('conv1x1', nn.Conv1d(1536, 500, 1)),
            ('gelu1', nn.GELU()),
            ('maxpool1', nn.MaxPool1d(896)),
            ('flatten', nn.Flatten()),
            ('fc1', nn.Linear(500, 18)),
        ]
        self.layers = nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run ``x`` through every stage of ``self.layers`` in order."""
        return self.layers(x)

In [10]:
pool_model = MLPModelPooling()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine,
# the same way the linear-probe checkpoint is loaded.
pool_model.load_state_dict(
    torch.load(
        '../cs282a_perceptron-maxpool/model_20231128_072156_3',
        map_location=torch.device('cpu'),
    )
)
pool_model.eval()

MLPModelPooling(
  (layers): Sequential(
    (conv1x1): Conv1d(1536, 500, kernel_size=(1,), stride=(1,))
    (gelu1): GELU(approximate='none')
    (maxpool1): MaxPool1d(kernel_size=896, stride=896, padding=0, dilation=1, ceil_mode=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=500, out_features=18, bias=True)
  )
)

In [11]:
# Inference only: disable autograd so no computation graph is built per sample.
# Only the last sample's predictions are kept; this cell just checks that
# inference runs end to end.
with torch.no_grad():
    for sample in dset:
        # Add a batch dimension, then put the 1536 channels before the 896 positions.
        inputs = torch.as_tensor(sample, dtype=torch.float32).reshape(1, 896, 1536)
        predictions = pool_model(inputs.transpose(1, 2))

#### 4. Transformer

In [12]:
class TransformerDecoder(nn.Module):
    """Single self-attention block followed by a per-position head and mean pooling.

    Despite the name, only self-attention is used: ``enc_out`` and ``src_mask``
    are accepted by ``forward`` but ignored, and ``max_length`` is unused.

    Args:
        d_model: Embedding size of each position.
        heads: Number of attention heads (must divide ``d_model``).
        forward_expansion: Hidden-size multiplier of the feed-forward sub-layer.
        dropout: Dropout probability for attention, feed-forward and residual paths.
        max_length: Unused; kept so existing constructor calls keep working.
        batch_first: Passed to ``nn.MultiheadAttention``. The pooling step in
            ``forward`` always treats dim 1 as the sequence axis, i.e. it
            assumes (batch, seq, d_model) input. With the default ``False`` the
            attention instead treats dim 0 as the sequence and dim 1 as the
            batch, so for a (1, L, d_model) input every position attends only to
            itself. ``False`` is kept as the default so models trained with the
            original class behave identically; pass ``True`` to attend across
            positions of batch-first input. The parameter set (and therefore the
            state_dict) is the same either way.
    """

    def __init__(self, d_model, heads, forward_expansion, dropout, max_length, batch_first=False):
        super(TransformerDecoder, self).__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=heads, dropout=dropout, batch_first=batch_first
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, forward_expansion * d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(forward_expansion * d_model, d_model)
        )

        self.dropout = nn.Dropout(dropout)

        # Per-position projection to the 18 output values.
        self.output_transform = nn.Linear(d_model, 18)

        # Averages over the sequence axis so the output length is always 1.
        self.sequence_pooling = nn.AdaptiveAvgPool1d(1)

    def forward(self, x, enc_out=None, src_mask=None, trg_mask=None):
        """Return predictions of shape (x.size(0), 1, 18).

        ``trg_mask`` is forwarded as the attention mask; ``enc_out`` and
        ``src_mask`` are ignored.
        """
        attention_output, _ = self.attention(x, x, x, attn_mask=trg_mask)
        # Residual + LayerNorm, followed by dropout.
        query = self.dropout(self.norm1(attention_output + x))

        out = self.feed_forward(query)
        out = self.dropout(self.norm2(out + query))

        out_transformed = self.output_transform(out)

        # AdaptiveAvgPool1d pools the last axis, so move dim 1 there and back:
        # (B, L, 18) -> (B, 18, L) -> (B, 18, 1) -> (B, 1, 18).
        return self.sequence_pooling(out_transformed.transpose(1, 2)).transpose(1, 2)

In [13]:
trained_basenji_transformer = TransformerDecoder(d_model=1536, heads=6, forward_expansion=2, dropout=0.2, max_length=896)
trained_filepath = '../cs282a_self-attention/model_20231128_080512_3'
# map_location lets a checkpoint saved on GPU load on a CPU-only machine,
# the same way the linear-probe checkpoint is loaded.
trained_basenji_transformer.load_state_dict(
    torch.load(trained_filepath, map_location=torch.device('cpu'))
)
trained_basenji_transformer.eval()

TransformerDecoder(
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)
  )
  (norm1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
  (feed_forward): Sequential(
    (0): Linear(in_features=1536, out_features=3072, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=3072, out_features=1536, bias=True)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (output_transform): Linear(in_features=1536, out_features=18, bias=True)
  (sequence_pooling): AdaptiveAvgPool1d(output_size=1)
)

In [14]:
# Inference only: disable autograd so no computation graph is built per sample.
# Only the last sample's predictions are kept; this cell just checks that
# inference runs end to end.
with torch.no_grad():
    for sample in dset:
        # Add a batch dimension: (896, 1536) -> (1, 896, 1536).
        inputs = torch.as_tensor(sample, dtype=torch.float32).reshape(1, 896, 1536)
        predictions = trained_basenji_transformer(inputs)