In [None]:
class SequenceTransformer(L.LightningModule):
    """Transformer encoder regressing 20 values per sequence position.

    The input is a one-hot encoded sequence of width ``ONE_HOT_DIM`` (20).
    It is linearly projected to ``d_model``, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks and projected back to 20 outputs
    per position (per the original comments: one ddG per possible mutation).
    Training minimises the MSE over the entries where ``batch['mask'] == 1``.

    NOTE(review): nothing adds positional information before the encoder,
    so self-attention cannot distinguish residue order -- confirm that this
    is intended.

    Args:
        lr: Learning rate for Adam.
        d_model: Width of the transformer embedding.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward
            sub-block.
        dropout: Dropout probability inside the encoder layers.
        weight_decay: Weight decay passed to Adam.
        lr_gamma: Multiplicative decay factor of the ``ExponentialLR``
            scheduler.
    """

    # Width of the one-hot input and of the per-position output.
    ONE_HOT_DIM = 20

    def __init__(self, lr: float = 1e-3, d_model: int = 64, nhead: int = 4,
                 num_layers: int = 2, dim_feedforward: int = 256,
                 dropout: float = 0.1, weight_decay: float = 1e-4,
                 lr_gamma: float = 0.95):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.lr_gamma = lr_gamma

        # project one-hot (20) to d_model dimensions
        self.input_proj = nn.Linear(self.ONE_HOT_DIM, d_model)

        # transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # output projection: d_model -> 20 (one ddG per possible mutation)
        self.output_proj = nn.Linear(d_model, self.ONE_HOT_DIM)

        self.loss_fn = nn.MSELoss()

    def forward(self, x):
        """Map a one-hot batch ``(B, L, 20)`` to predictions ``(B, L, 20)``."""
        out = self.input_proj(x)      # (B, L, d_model)
        out = self.transformer(out)   # (B, L, d_model)
        out = self.output_proj(out)   # (B, L, 20)
        return out

    def _masked_mse(self, batch):
        """Return the MSE over entries where ``batch['mask'] == 1``.

        Returns ``None`` when the mask selects nothing: the mean over an
        empty selection is NaN, and back-propagating it would turn every
        weight into NaN.
        """
        selected = batch['mask'] == 1
        if not selected.any():
            return None
        pred = self(batch['sequence'])
        return self.loss_fn(pred[selected], batch['labels'][selected])

    def training_step(self, batch, batch_idx):
        """Compute and log the masked training loss for one batch.

        ``batch`` is a dict with keys ``'sequence'``, ``'mask'`` and
        ``'labels'`` (all ``(1, L, 20)`` according to the original comments).

        Returns ``None`` for a batch without any masked-in entry, which makes
        Lightning skip it under automatic optimization. Lightning documents
        this as unsupported for multi-GPU, TPU and DeepSpeed runs.
        """
        loss = self._masked_mse(batch)
        if loss is None:
            return None
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the masked validation loss for one batch.

        Batches without any masked-in entry are not logged, so they cannot
        turn the aggregated ``val_loss`` into NaN.
        """
        loss = self._masked_mse(batch)
        if loss is None:
            return None
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return Adam plus an exponentially decaying learning-rate schedule."""
        optimizer = optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.lr_gamma)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

In [None]:
# Data: SequenceData is the dataset class introduced in notebook 3.
# Every batch holds a single sequence; only the training split is shuffled.
dataset_train = SequenceData('data/mega_train.csv')
loader_train = DataLoader(dataset_train, batch_size=1, shuffle=True)

dataset_val = SequenceData('data/mega_val.csv')
loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False)

# Model and trainer: 20 epochs on a single device.
model = SequenceTransformer(
    lr=1e-3,
    d_model=64,
    nhead=4,
    num_layers=2,
)
trainer = L.Trainer(devices=1, max_epochs=20)
trainer.fit(
    model,
    train_dataloaders=loader_train,
    val_dataloaders=loader_val,
)