In [12]:
import ast

import pandas as pd
from skorch import torch, dataset, NeuralNetClassifier
torch.manual_seed(0)

<torch._C.Generator at 0x7f151cd10790>

In [19]:
data = (
    pd.read_csv(
        "../data/pancluster/pancluster.full.tsv",
        sep="\t",
        usecols=["genome", "gene_family"],
        converters={"gene_family": ast.literal_eval}
    )
    .groupby("genome")
    .aggregate(list)["gene_family"]
    # Zero will be reserved for padding, so add 1 to every gene family
    .apply(lambda x: list(map(lambda y: [i+1 for i in y], x)))
)
data

genome
GCA_000149955.2    [[508, 472, 415, 419, 410, 334, 262, 211, 128,...
GCA_000222805.1    [[548, 390, 300, 277, 272, 315, 93, 44, 229, 2...
GCA_000259975.2    [[318, 289, 364, 194, 298, 113, 35, 59, 30, 85...
GCA_000260175.2    [[377, 372, 379, 381, 376, 361, 2104, 399, 400...
GCA_000260215.2    [[89, 91, 18, 21, 2, 337, 366, 13, 8, 66, 54],...
                                         ...                        
GCA_032878545.1    [[508, 472, 415, 419, 410, 334, 475, 262, 211,...
GCA_032991405.1    [[430, 414, 401, 281, 552, 275, 259, 87, 77, 2...
GCA_034509825.1    [[430, 414, 551, 401, 281, 275, 259, 87, 77, 2...
GCA_036785135.1    [[423, 392, 346, 347, 344, 342, 159, 65, 172, ...
GCA_038050555.1    [[430, 414, 401, 281, 552, 275, 259, 87, 77, 2...
Name: gene_family, Length: 242, dtype: object

In [22]:
target = pd.read_csv(
    "../accessions.tsv",
    sep="\t",
    index_col="genome"
)["fsp"].astype("category")

In [None]:
class PanclusterModule(torch.nn.Module):

    def __init__(
        self,
        unique_genes: int,
        embedding_dim: int,
        hidden_size: int,
        num_classes: int,
        device: torch.device = torch.device("cpu")
    ):
        super(PanclusterModule, self).__init__()

        self.device = device

        # The embedding allows us to treat gene families as categories
        self.embedding = torch.nn.Embedding(
            num_embeddings=unique_genes,
            embedding_dim=embedding_dim,
            padding_idx=0,
            device=self.device
        )

        # The LSTM learns the important features of each BGC
        self.lstm = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
            device=self.device
        )

        # Attention weights the importance of each BGC
        self.attention = torch.nn.Linear(
            in_features=hidden_size,
            out_features=1,
            device=self.device
        )

        # The final layer maps to the classes we wish to predict
        self.output = torch.nn.Linear(
            in_features=hidden_size,
            out_features=num_classes,
            device=self.device
        )
    
    def forward(self, batch: torch.Tensor):

        # The batch is a zero-padded 3D tensor of shape
        # (n_genomes, max_bgc_count, max_bgc_length)

        # # In order to train efficiently, we will extract all BGC tensors into
        # # a single list
        # bgcs = [bgc for genome in batch for bgc in genome]
        # bgc_lengths = torch.tensor(list(map(len, bgcs)))

        # # LSTMs require that input sequences are sorted by length in descending
        # # order. However, we must keep track of the indices used for sorting
        # # so that, at a later stage, when splitting tensors to batches, each set
        # # of BGCs is restored to its original position
        # sorted_indices = torch.argsort(bgc_lengths, descending=True)
        # restored_indices = torch.argsort(sorted_indices)
        # bgcs = [bgcs[index] for index in sorted_indices]
        # bgc_lengths = bgc_lengths[sorted_indices]

        # # Because each BGC contains different number of genes, we pad them with
        # # zeros, allowing us to create two dimensional tensors
        # # Example: [[1 2 3], [1 2]] -> [[1 2 3] [1 2 0]]
        # padded_bgcs = torch.nn.utils.rnn.pad_sequence(
        #     sequences=bgcs,
        #     batch_first=True,
        #     padding_value=0
        # )

        original_shape = batch.shape
        padded_bgcs = batch.flatten(0, 1)

        # We are now ready to pass the BGCs to the embedding layer
        embedded_bgcs: torch.Tensor = self.embedding(padded_bgcs)

        # Pack the BGC embeddings
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            input=embedded_bgcs,
            lengths=bgc_lengths,
            batch_first=True,
            enforce_sorted=False
        )

        # Pass packed BGC embeddings through LSTM
        packed_output: torch.nn.utils.rnn.PackedSequence = (
            self.lstm(packed_input)[0]
        )

        # Unpack LSTM outputs
        padded_output = torch.nn.utils.rnn.pad_packed_sequence(
            sequence=packed_output,
            batch_first=True
        )[0]

        # Compute attention scores
        attention_scores: torch.Tensor = self.attention(padded_output)
        attention_scores.squeeze_(dim=-1)
        attention_weights = torch.softmax(attention_scores, dim=1).unsqueeze(-1)

        # Compute the weighted sum of LSTM outputs and restore original order
        context_vector = torch.sum(padded_output * attention_weights, dim=1)
        context_vector = context_vector[restored_indices]

        # Reshape back into batch format
        batch_sizes = list(map(len, batch))
        instance_vectors = torch.split(context_vector, batch_sizes, dim=0)

        # Aggregate BGCs per instance
        batch_output = torch.stack([
            instance.mean(dim=0) for instance in instance_vectors
        ])

        # Pass outputs to final fully connected layer
        final_output: torch.Tensor = self.output(batch_output)

        return final_output

In [26]:
device = torch.device("cpu")
unique_genes = data.explode().explode().max() + 1
embedding_dim = 5
hidden_size = 10
num_classes = 10

In [27]:
# Maximum number of BGCs in a genome
max_bgc_count = data.apply(len).max()

# Maximum number of genes in a BGC
max_bgc_length = data.apply(lambda x: max(map(len, x))).max()
max_bgc_count, max_bgc_length

(57, 35)

In [28]:
# Pad genomes with dummy BGCs so all genomes have the same number of BGCs
data = data.apply(lambda x: x + [[0]] * (max_bgc_count - len(x)))

In [29]:
# Pad each BGC with dummy genes such that all BGCs have the same length
data = data.apply(lambda x: [y + [0] * (max_bgc_length - len(y)) for y in x])

In [51]:
X = torch.tensor(data.to_list(), device=device)
X.shape

torch.Size([242, 57, 35])

In [53]:
embedding = torch.nn.Embedding(
    num_embeddings=unique_genes,
    embedding_dim=embedding_dim,
    padding_idx=0,
    device=device
)
embedding(X.flatten(0, 1)).shape

torch.Size([13794, 35, 5])

In [None]:
y = torch.tensor(target.cat.codes.tolist(), device=device)

In [None]:
net = NeuralNetClassifier(
    module=PanclusterModule,
    module__unique_genes=unique_genes,
    module__embedding_dim=embedding_dim,
    module__hidden_size=hidden_size,
    module__num_classes=num_classes,
    module__device=device,
    device="cuda",
    criterion=torch.nn.CrossEntropyLoss,
    train_split=dataset.ValidSplit(cv=4, stratified=True),
    max_epochs=10,
    lr=0.001,
    verbose=1
)

In [10]:
net.fit(X=X, y=y)

ValueError: Dataset does not have consistent lengths.