In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

from src.utils import inference, visualize_predictions
from src.data import RNADataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class SimpleRNAPredictor(nn.Module):
    def __init__(self, input_size=5, hidden_size=64, out_size=61, num_hidden=6):
        super().__init__()
        self.in_layer = nn.Linear(input_size, hidden_size)
        self.out_layer = nn.Linear(hidden_size, out_size)
        self.hidden_layers = []
        
        # for _ in range(num_hidden):
        #     self.hidden_layers.append(nn.Linear(hidden_size, hidden_size))
        
        self.activation = nn.ReLU()

    def forward(self, x):
        out = self.activation(self.in_layer(x))
        # for layer in self.hidden_layers:
        #     out = self.activation(layer(out))
        out = nn.Softmax()(self.out_layer(out))
        return out

In [3]:
def load_model(checkpoint_path, device=None):
    """
    Load model and optimizer state from checkpoint.
    
    Args:
        checkpoint_path: Path to the checkpoint file
        device: torch device (if None, will use cuda if available)
    
    Returns:
        model: Loaded model
        optimizer: Loaded optimizer
        checkpoint_info: Dict containing epoch and loss
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Initialize model and move to device
    model = SimpleRNAPredictor().to(device)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Get checkpoint info
    checkpoint_info = {
        'epoch': checkpoint['epoch'],
        'loss': checkpoint['loss']
    }
    
    print(f"Loaded model from epoch {checkpoint_info['epoch']}")
    print(f"Best loss: {checkpoint_info['loss']:.4f}")
    
    return model, checkpoint_info

In [4]:
test_dataset = RNADataset.load("data/test_dataset.pkl")
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

In [5]:
model, info = load_model('best_model.pt', device=torch.device('cuda:0'))
model

Loaded model from epoch 2
Best loss: 3.9046


  checkpoint = torch.load(checkpoint_path, map_location=device)


SimpleRNAPredictor(
  (in_layer): Linear(in_features=5, out_features=64, bias=True)
  (out_layer): Linear(in_features=64, out_features=61, bias=True)
  (activation): ReLU()
)

In [6]:
# inference(model, test_loader, device)

In [7]:
visualize_predictions(model, test_loader, device)


Visualizing 5 example predictions:

Example:
Sequence:  GGGAUUGUAGUUCAAUUGGUCAGAGCACCGCCCUGUCAAGGCGGAAGCUGCGGGUUCGAGCCCCGUCAGUCCCG
Predicted: (((.WW(W.(WW[..WW(()).(.().))())))())..(()((..())()((()))(.())))()).())))(
Ground Tr: (((((((..((((.........))))((((((.......))))))....(((((.......)))))))))))).

Example:
Sequence:  GGCCCCUUGGUCAAGCGGUUAAGACACCACCCUUUCACGGUGGUAACAGGGGUUCGAUUCCCCUAGGGGUCACCA
Predicted: (([[[[WW((W[..([(())..(.).)).))))))).)(()(()..).(((()))(.))))))).(((()).)).
Ground Tr: (((((((..((((........))))((((((.......))))))...(((((.......))))))))))))....

Example:
Sequence:  CACUUUAGCUGAGUUAGUGAUUGGCUAAAGCUUAUUAAUGCAUUAUUUGGAGAGACAAAAUGUCACUAAAUGCUGAACAAACUGCAACAAUCCUGGCUGAAUUCGGUCGUAGUGAA
Predicted: [.[WWW.([W(.(WW.(W(.))(())...())).))..)().)).)))((.(.(.)....)(.).))...)())(..)...))()..)..))))(())(..)))(())().()(..
Ground Tr: ......((((((........))))))...................(((((((...........)))))))..............................................

Example:
Sequence:  GGUUUGAAUG

  return self._call_impl(*args, **kwargs)



Example:
Sequence:  GCGGACAUAGCUUAGUUGGUAAAGCGCAACCUUGCCAAGGUUGAGACCGCGGGUUCGAGUCCCGUUGUCCGCUCUA
Predicted: ([((.[.W.([WW.(WW((W...()()..))))())..(())(.(.))()((()))(.())))())()))()))).
Ground Tr: (((((((..((((........))))((((((.......))))))....(((((.......))))))))))))....

Example:
Sequence:  GUGCUCGGUUUGUAGGCAGUGUCAUUAGCUGAUUGUACUGUGGUGGUUACAAUCACUAACUCCACUGCCAUCAAAACAAGGCAC
Predicted: (W([W[((WWW(W.(([.(W(W).)).())(.))().))()(()(()).)..)).))..)))).))()).))....)..(().)
Ground Tr: .(.((((.(((((.((((((((.(((((.((((((..............))))))))))).)))))))).))))).)))))...

Example:
Sequence:  GGCGGAUGUCAGCGGUUCGAGUCCGCUUAUCUCCA
Predicted: (([((.W(W[.([((WW[(.(W))())).))))).
Ground Tr: ..........(((((.......)))))........

Example:
Sequence:  AGGUGGGCCACGCCUCCCCACCGAGUGCGCGACCUAUCUGGAAGGAUAGGAGGA
Predicted: .((W((([[.[([[W[[[[.[[(.()()()(.))).)))((..((.).((.((.
Ground Tr: ..........((((((......))).)))....((((((....)))))).....

Example:
Sequence:  ACGCUGGCGGCGUGCUUAACACAUGCAAGUCGAACGAUGAAGCCGCUU

In [8]:
batch = next(iter(test_loader))

In [9]:
batch.keys()

dict_keys(['sequence', 'structure', 'attention_mask', 'length', 'raw_sequence', 'raw_structure'])

In [10]:
batch['structures'][10].detach().cpu().numpy().tolist()

KeyError: 'structures'