In [None]:
import torch
from torch import nn
import torchviz

import vit
import scipy.io
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from importlib import reload
from mat import mat

In [2]:
mat.keys()

dict_keys(['__header__', '__version__', '__globals__', 'all_images', 'all_dnas', 'all_labels', 'all_dnas_norepeat', 'all_dna_labels_norepeat', 'all_boldids', 'train_loc', 'val_seen_loc', 'val_unseen_loc', 'test_seen_loc', 'test_unseen_loc', 'species2genus', 'described_species_labels_train', 'described_species_labels_trainval', 'all_dna_features_cnn_original', 'all_image_features_resnet', 'all_image_features_gan', 'all_dna_features_cnn_new', 'all_string_dnas'])

In [3]:
species2genus = mat['species2genus']-1

# group species by genus

genus_species = dict()
max_specie_in_genus = 0
for genus_id, genus in pd.DataFrame(species2genus, columns=['genus']).groupby('genus'):
    specie_indices = genus.index.tolist()
    genus_species[genus_id] = specie_indices
    if len(specie_indices) > max_specie_in_genus:
        max_specie_in_genus = len(specie_indices)

print(len(genus_species))
print("Max specie in genus: ", max_specie_in_genus)


372
Max specie in genus:  23


In [None]:
# genus = lambda s: mat['species2genus'][s]-1
# genus(156)

array([62])

In [None]:
# group labels count
# pd.Series(mat['all_labels'].squeeze()).value_counts()


1038    759
977     540
1039    361
418     292
979     292
       ... 
529       4
578       2
530       2
533       2
156       1
Name: count, Length: 1050, dtype: int64

In [None]:
# from sklearn.decomposition import PCA

# pca = PCA(n_components=512)
# all_dna_features_cnn_pca = pca.fit_transform(mat['all_dna_features_cnn_new'])


In [None]:
# pca = PCA(n_components=512)
# all_image_features_gan_pca = pca.fit_transform(mat['all_image_features_gan'])

In [None]:
# mat['species2genus'].shape

(1050, 1)

In [None]:
# x = np.array(list(map(lambda s: len(s.strip()), mat['all_string_dnas'])))
# np.unique(x).size

120

In [None]:

all_dna_len = list(map(lambda s: len(s.strip()), mat['all_string_dnas']))
dna_str_len_mapping: dict[int,int] = {}

def dna_str_len_to_int(s_len):
    if s_len not in dna_str_len_mapping:
        dna_str_len_mapping[s_len] = len(dna_str_len_mapping)
    return dna_str_len_mapping[s_len]

# def all_dna_len_token():
#     return list(map(dna_str_len_to_int, all_dna_len))

all_dna_len_tokens = list(map(dna_str_len_to_int, all_dna_len))
print(list(zip(all_dna_len, all_dna_len_tokens)))

[(658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (666, 1), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (638, 2), (638, 2), (638, 2), (658, 0), (700, 3), (679, 4), (659, 5), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0), (658, 0),

In [7]:
deviceGPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")
deviceCPU = torch.device("cpu")

device = deviceGPU
device

device(type='cuda')

In [None]:
from torch.utils.data import Dataset

class MultiModalDataset:
    def __init__(self, dna_strings, images, labels, dna_str_len_mapping, species2genus, genus_species, img_processor, dna_tokenizer, max_length=1600):
        self.images = images
        self.dna_strings = dna_strings
        self.labels = labels
        self.img_processor = img_processor
        self.dna_tokenizer = dna_tokenizer
        self.dna_str_len_mapping = dna_str_len_mapping
        self.species2genus = species2genus
        self.max_length = max_length
        self.genus_species = genus_species

    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # ===== Image Processing =====
        image = self.images[idx].permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        

        image_encoding = self.img_processor(images=image, return_tensors="pt")
        pixel_values = image_encoding['pixel_values'].squeeze(0)

        # ===== DNA Processing =====
        dna_sequence = self.dna_strings[idx].strip()
        dna_len_token = self.dna_str_len_mapping.get(len(dna_sequence), 0)  # default 0 if not found

        dna_encoding = self.dna_tokenizer(
            dna_sequence,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        input_ids = dna_encoding['input_ids'].squeeze(0)
        attention_mask = dna_encoding['attention_mask'].squeeze(0)

        # ===== Label & Genus =====
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        genus = torch.tensor(self.species2genus[idx], dtype=torch.long)

        return {
            'pixel_values': pixel_values,
            'dna_len_token': torch.tensor(dna_len_token, dtype=torch.long),
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label,
            'genus': genus
        }

In [None]:
import vit
reload(vit)
from vit import get_processor_encoder, get_img_embedding
img_processor, img_encoder = get_processor_encoder("./vit-finetuned7-final", device)
get_img_embedding(mat['all_images'][:2], img_processor, img_encoder, device).shape

Some weights of ViTModel were not initialized from the model checkpoint at ./vit-finetuned7-final and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([10, 768])

In [None]:
import dnaencoder
reload(dnaencoder)
from dnaencoder import get_tokenizer_encoder, get_dna_embedding
dna_tokenizer, dna_encoder = get_tokenizer_encoder("./dnaencoder-finetuned1755100772-final", device)
get_dna_embedding(mat['all_string_dnas'][:2], dna_tokenizer, dna_encoder).shape

<class 'numpy.ndarray'>


torch.Size([10, 512])

In [None]:
dataset = MultiModalDataset(mat['all_string_dnas'], mat['all_images'], mat['all_labels'], all_dna_len_tokens, species2genus, genus_species, img_processor, dna_tokenizer)

In [None]:
import models
reload(models)
from models import AttentionFusion, GenusClassifier



fusion = AttentionFusion(
    dna_dim=512,
    img_dim=768,
    dna_len_dim=16,
).to(device)
genus_predictor = GenusClassifier(mat['species2genus'], fusion).to(device)

genus_predictor.to(deviceGPU)
x = genus_predictor(dataset[0:2]['dna_len_token'].unsqueeze(0).to(deviceGPU), 
                get_dna_embedding(dataset[0:2]['dna_len_token'].unsqueeze(0).to(deviceGPU), 
                                  get_img_embedding(dataset[0:2]['pixel_values'].unsqueeze(0).to(deviceGPU), img_processor, img_encoder)))
print(x.shape, x)
# genus_predictor.fit(all_dna_len_tokens,
#             all_dna_features_cnn_pca, 
#               all_image_features_gan_pca, 
#               mat['all_labels'].squeeze(), 
#               mat['val_seen_loc'].squeeze(), 
#               mat['train_loc'].squeeze(), 
#               1000, 
#               lr=0.005,
#               eval_frequency=2,
#               device=deviceGPU)

SyntaxError: invalid syntax (3972358454.py, line 9)

In [None]:
genus_predictor.load_state_dict(torch.load('output/Tue Aug 12 20:36:21 2025_best_genus_predictor.pt'))
genus_predictor.to(deviceCPU)

GenusPredictor(
  (fusion_encoder): AttentionFusion(
    (dna_len_emb): Embedding(120, 16)
    (proj_dna): Linear(in_features=512, out_features=496, bias=True)
    (proj_img): Linear(in_features=512, out_features=512, bias=True)
    (ffn): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
    )
  )
  (decoder): Decoder(
    (ffn): Sequential(
      (0): Linear(in_features=512, out_features=744, bias=True)
      (1): Sigmoid()
      (2): Linear(in_features=744, out_features=372, bias=True)
    )
  )
  (criterion): CrossEntropyLoss()
)

In [None]:
import models
reload(models)
from models import AttentionFusion, MainClassifier, Decoder, GenusClassifier

specie_predictor = MainClassifier(mat['species2genus'],genus_species, genus_predictor).to(deviceCPU)
specie_predictor.fit(
    all_dna_len_tokens,
    all_dna_features_cnn_pca, 
    all_image_features_gan_pca, 
    mat['all_labels'].squeeze(), 
    mat['val_seen_loc'].squeeze(), 
    mat['train_loc'].squeeze(), 
    200, 
    lr=0.005,
    eval_frequency=10,
    freeze_genus=True,
    teacher_force=True,
    device=deviceCPU)

In [26]:
all_dna_features_cnn_pca.shape

(32424, 768)

In [None]:
np.unique(mat['all_labels'])

array([   1,    2,    3, ..., 1048, 1049, 1050], shape=(1050,))

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DNAImageDecoder(nn.Module):
    def __init__(self, N_dna, N_image, d_model=128, num_heads=4, num_classes=10):
        super().__init__()
        
        # Project DNA and image embeddings into same space
        self.dna_proj = nn.Linear(N_dna, d_model)
        self.img_proj = nn.Linear(N_image, d_model)
        
        # Self-attention mechanism
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
        
        # Feed-forward layer after attention
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, num_classes)
        )
        
    def forward(self, dna_emb, img_emb):
        """
        dna_emb: (batch_size, N_dna)
        img_emb: (batch_size, N_image)
        """
        # Project to same dimension
        dna_token = self.dna_proj(dna_emb).unsqueeze(1)  # (batch, 1, d_model)
        img_token = self.img_proj(img_emb).unsqueeze(1)  # (batch, 1, d_model)
        
        # Sequence: [DNA, Image]
        seq = torch.cat([dna_token, img_token], dim=1)  # (batch, 2, d_model)
        
        # Self-attention
        attn_out, _ = self.attn(seq, seq, seq)  # (batch, 2, d_model)
        
        # Pooling — use first token (DNA) or mean-pool
        pooled = attn_out.mean(dim=1)  # (batch, d_model)
        
        # Classification
        logits = self.ffn(pooled)  # (batch, num_classes)
        return logits


In [None]:
# ma = nn.MultiheadAttention(768, 4)
# attn_output, attn_output_weights = (ma(torch.rand(1, 1, 768), torch.rand(1, 1, 768), torch.rand(1, 1, 768)))


In [None]:
# attn_output_weights

tensor([[[1.]]], grad_fn=<MeanBackward1>)

In [None]:
# dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
# dna_emb = dnaencoder.model(**dnaencoder.tokenizer(dna, return_tensors='pt',))


odict_keys(['last_hidden_state', 'pooler_output'])

In [None]:
# from transformers import AutoImageProcessor, AutoModel
# from PIL import Image
# import torch

# # Load pretrained ViT model
# model_name = "google/vit-base-patch16-224"
# imageprocessor = AutoImageProcessor.from_pretrained(model_name)
# vitmodel = AutoModel.from_pretrained(model_name)
# vitmodel.eval()

# # Put model on GPU if available
# deviceCPU = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# vitmodel.to(deviceCPU)

# def get_vit_embedding(image_np, model_name="google/vit-base-patch16-224", device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
#     inputs = imageprocessor(images=image_np, return_tensors="pt")
#     inputs = {k: v.to(device) for k, v in inputs.items()}

#     with torch.no_grad():
#         outputs = vitmodel(**inputs)
    
#     return outputs.last_hidden_state[:, 0]  # CLS token


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
