In [1]:
from mlm import *
from fairseq import checkpoint_utils, options, tasks, utils, data, options
import os
import sys
sys.path.append("..")
from models.mlm import *
from criterions.mlm import *
from fairseq.data import data_utils
from Bio import SeqIO
from torch.utils.data import Dataset
import torch.nn as nn

In [2]:
# load pre-trained model, related task and args

mlm_pretrained_model = '../../../../rna-emb/RNAcentral--rna_mlm_base-MAXLEN1024-ckpt/checkpoint_best.pt'
arg_overrides = { "data": '../../../../../../data/RNAcentral_data/rnacentral_data-100_ftvocab_bin/' }

models, args, task = checkpoint_utils.load_model_ensemble_and_task(mlm_pretrained_model.split(os.pathsep), 
                                                                   arg_overrides=arg_overrides)
model = models[0]

In [15]:
from collections import OrderedDict


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif classname.find('BasicConv') != -1:   # for googlenet
        pass
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if isinstance(m.bias, nn.Parameter):
            nn.init.constant_(m.bias, 0.0)
            

class rna_ss_resnet32(nn.Module):
    def __init__(self, model, depth=32):
        super().__init__()
        self.backbone = model.encoder
        self.depth = 32
        # 进入ResNet32之前：1、fc到128
        self.fc1 = nn.Linear(in_features=768, out_features=128)
        # nn.init.kaiming_normal_(self.fc1.weight, a=0, mode='fan_out')
        # nn.init.constant_(self.fc1.bias, 0.0)
        # nn.init.normal_(self.fc1.weight, std=0.001)
        # if isinstance(self.fc1.bias, nn.Parameter):
        #     nn.init.constant_(self.fc1.bias, 0.0)
        self.fc1.apply(weights_init_kaiming)
        self.fc1.apply(weights_init_classifier)
            
        # 进入ResNet32之前：2、conv到64    
        self.Conv2d_1 = nn.Conv2d(in_channels = 256, out_channels = 64, kernel_size = 1)
        nn.init.kaiming_normal_(self.Conv2d_1.weight, a=0, mode='fan_in')
        if self.Conv2d_1.bias is not None:
            nn.init.constant_(self.Conv2d_1.bias, 0.0)
        self.Conv2d_1.apply(weights_init_kaiming)
         
        # 定义ResNet32参数
        res_layers = []
        for i in range(self.depth):
            dilation = pow(2, (i % 3))
            res_layers.append(MyBasicResBlock(inplanes=64, planes=64, dilation=dilation))
        res_layers = nn.Sequential(*res_layers)

        # final_layer = nn.Conv2d(64, 2, kernel_size=3, padding=1)
        final_layer = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        layers = OrderedDict()
        layers["resnet"] = res_layers
        layers["final"] = final_layer
        
        self.proj = nn.Sequential(layers)
        self.proj.apply(weights_init_kaiming)
        self.proj.apply(weights_init_classifier)
        
        
    def forward(self,x):
        _, x = self.encoder(x, segment_labels=None, masked_tokens=None,
                extra_only=False, masked_only=False)
        x = x['inner_states'][-1][1:-1,:,:].transpose(0, 1) # (T, B, C) -> (B, T, C)
        # Batch为1
        # x shape like [B,T,C]
        x = self.fc1(x) # -> [B,T,128]
        # x = x.squeeze()
        batch_size, seqlen, hiddendim = x.size()
        x = x.unsqueeze(2).expand(batch_size, seqlen, seqlen, hiddendim)
        x_T = x.permute(0,2,1,3)
        x_concat = torch.cat([x,x_T],dim=3) # -> [B,T,T,C*2]
        x = x_concat.permute(0,3,1,2) # -> [B,C*2,T,T]
        x = self.Conv2d_1(x)
        # ResNet32+output的conv处理
        x = self.proj(x)
        upper_triangular_x = torch.triu(x)
        lower_triangular_x = torch.triu(x,diagonal=1).permute(0,1,3,2)
        output = upper_triangular_x + lower_triangular_x
        # return shape like [B,1,L,L]
        return output
    
class MyBasicResBlock(nn.Module):
    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
    ) -> None:
        super(MyBasicResBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        # cjy commented
        #if dilation > 1:
        #    raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1, bias = False)
        self.dropout = nn.Dropout(p=0.3)
        self.relu2 = nn.ReLU(inplace=True)
        # self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
        #self.bn2 = norm_layer(planes)
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)
        out = self.dropout(out)
        out = self.relu2(out)
        out = self.conv2(out)

        out += identity
        return out
    

# class ResNet32_init(nn.Module):
#     def __init__(self, model, depth=32):
#         super().__init__()
#         self.encoder = model.encoder
#         self.depth = depth
#         self.dropout_value = 0.3
#         # before ResNet
#         self.fc1 = nn.Linear(in_features=768, out_features=128)
#         nn.init.kaiming_normal_(self.fc1.weight, a=0, mode='fan_out')
#         nn.init.constant_(self.fc1.bias, 0.0)
#         nn.init.normal_(self.fc1.weight, std=0.001)
#         if isinstance(self.fc1.bias, nn.Parameter):
#             nn.init.constant_(self.fc1.bias, 0.0)

#         self.Conv2d_1 = nn.Conv2d(in_channels = 256, out_channels = 64, kernel_size = 1)
#         nn.init.kaiming_normal_(self.Conv2d_1.weight, a=0, mode='fan_in')
#         if self.Conv2d_1.bias is not None:
#             nn.init.constant_(self.Conv2d_1.bias, 0.0)
#         # self.maxout = Maxout.apply

        
#         self.BN2_iter = nn.BatchNorm2d(num_features = 64)
#         nn.init.constant_(self.BN2_iter.weight, 1.0)
#         nn.init.constant_(self.BN2_iter.bias, 0.0)

#         self.RELU_iter = nn.ReLU(inplace=True)
#         self.Conv2d_1_iter = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1, bias = False)
#         nn.init.kaiming_normal_(self.Conv2d_1_iter.weight, a=0, mode='fan_in')

#         self.Conv2d_2_iter = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, dilation = 1, padding = 1, bias = False)
#         nn.init.kaiming_normal_(self.Conv2d_2_iter.weight, a=0, mode='fan_in')
#         self.Conv2d_3_iter = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, dilation = 2, padding = 2, bias = False)
#         nn.init.kaiming_normal_(self.Conv2d_3_iter.weight, a=0, mode='fan_in')
#         self.Conv2d_4_iter = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, dilation = 4, padding = 4, bias = False)
#         nn.init.kaiming_normal_(self.Conv2d_4_iter.weight, a=0, mode='fan_in')
        
#         self.Dropout_iter = nn.Dropout(p = self.dropout_value)
        
#         self.BN2_2 = nn.BatchNorm2d(num_features = 64)
#         nn.init.constant_(self.BN2_2.weight, 1.0)
#         nn.init.constant_(self.BN2_2.bias, 0.0)

#         self.Conv2d_2 = nn.Conv2d(in_channels = 64, out_channels = 1, kernel_size = 3, padding = 1)
#         nn.init.kaiming_normal_(self.Conv2d_2.weight, a=0, mode='fan_in')
#         if self.Conv2d_2.bias is not None:
#             nn.init.constant_(self.Conv2d_2.bias, 0.0)
#     def forward(self, x):
#         _, x = self.encoder(x, segment_labels=None, masked_tokens=None,
#                 extra_only=False, masked_only=False)
#         x = x['inner_states'][-1][1:-1,:,:].transpose(0, 1) # (T, B, C) -> (B, T, C)
#         # Batch为1
#         # x shape like [B,T,C]
#         x = self.fc1(x) # -> [B,T,128]
#         # x = x.squeeze()
#         batch_size, seqlen, hiddendim = x.size()
#         x = x.unsqueeze(2).expand(batch_size, seqlen, seqlen, hiddendim)
#         x_T = x.permute(0,2,1,3)
#         x_concat = torch.cat([x,x_T],dim=3) # -> [B,T,T,C*2]
#         x = x_concat.permute(0,3,1,2) # -> [B,C*2,T,T]
#         res = self.Conv2d_1(x)
#         # res = self.maxout(x)
#         d_rate = 1
#         for i in range(self.depth):
#             x = self.BN2_iter(res)
#             x = self.RELU_iter(x)
#             x = self.Conv2d_1_iter(x)
#             x = self.Dropout_iter(x)
#             x = self.RELU_iter(x)
#             if d_rate == 1:
#                 x = self.Conv2d_2_iter(x)
#                 d_rate = 2
#             elif d_rate == 2:
#                 x = self.Conv2d_3_iter(x)
#                 d_rate = 4
#             else:
#                 x = self.Conv2d_4_iter(x)
#                 d_rate = 1
#             res = torch.add(x,res)
#         # x = self.BN2_2(res)
#         # x = self.ReLU(x)
#         x = self.Conv2d_2(x)
#         # x = x.squeeze(dim=1)
#         upper_triangular_x = torch.triu(x)
#         lower_triangular_x = torch.triu(x,diagonal=1).permute(0,1,3,2)
#         output = upper_triangular_x + lower_triangular_x
#         # return shape like [B,1,L,L]
#         return output

In [21]:
test_model = rna_ss_resnet32(model)

In [22]:
for _,param in enumerate(test_model.named_parameters()):
    print(param[0])
    print(param[1])
    print('----------------')

backbone.lm_output_learned_bias
Parameter containing:
tensor([-0.7256, -0.4961, -0.7461, -0.4971,  0.0071,  0.0061,  0.0099,  0.0056,
        -0.1783, -0.5601, -0.5991, -0.5112, -0.5112, -0.4587, -0.4902, -0.3201,
        -0.3760, -0.3784, -0.3970, -0.2549, -0.5869, -0.6099, -0.6064, -0.6089,
        -1.7812], requires_grad=True)
----------------
backbone.sentence_encoder.embed_tokens.weight
Parameter containing:
tensor([[-0.0240,  0.0423,  0.1136,  ..., -0.0842, -0.0219, -0.1686],
        [ 0.0050,  0.1285,  0.1570,  ..., -0.0750,  0.0038, -0.0883],
        [ 0.1110,  0.0334, -0.0156,  ..., -0.2527,  0.0149, -0.1470],
        ...,
        [ 0.0673,  0.0355,  0.0124,  ..., -0.0996, -0.0599, -0.1895],
        [ 0.0708,  0.0415,  0.0312,  ..., -0.1113, -0.0539, -0.2089],
        [ 0.0578,  0.0124,  0.0042,  ..., -0.1160, -0.1186, -0.1936]],
       requires_grad=True)
----------------
backbone.sentence_encoder.segment_embeddings.weight
Parameter containing:
tensor([[-0.0175, -0.0045, -0.0

In [24]:
200 %100

0

In [3]:
input = torch.Tensor([0,4,6,5,6,7,5,4,6,7,6,5,4,2]).long()

In [9]:
def get_para_num(model):
    lst = []
    for para in model.parameters():
        lst.append(para.nelement())
    print(f"total paras number: {sum(lst)}")

def get_trainable_para_num(model):
    lst = []
    for para in model.parameters():
        if para.requires_grad == True:
            lst.append(para.nelement())
    print(f"trainable paras number: {sum(lst)}")

In [13]:
get_para_num(test_model)

total paras number: 88739866


In [8]:
inner,_ = model.encoder.sentence_encoder(input.unsqueeze(0))
inner[-1].shape

torch.Size([14, 1, 768])

In [12]:
output = test_model(input.unsqueeze(0))

In [40]:
criterion_bce_weighted = torch.nn.BCEWithLogitsLoss(pos_weight = torch.Tensor([300]))

In [41]:
true = torch.randint(1,2,(1,12,12)).float()
true

tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])

In [47]:
print(test_model)

rna_ss_resnet32(
  (backbone): RNAMaskedLMEncoder(
    (sentence_encoder): TransformerSentenceEncoder(
      (dropout_module): FairseqDropout()
      (embed_tokens): Embedding(25, 768, padding_idx=1)
      (segment_embeddings): Embedding(2, 768)
      (embed_positions): SinusoidalPositionalEmbedding()
      (layers): ModuleList(
        (0): TransformerSentenceEncoderLayer(
          (dropout_module): FairseqDropout()
          (activation_dropout_module): FairseqDropout()
          (self_attn): MultiheadAttention(
            (dropout_module): FairseqDropout()
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_fea

In [20]:
seq_len = 6
vectors=np.empty([2,seq_len+2])
vectors[:,0] = task.dictionary.bos()
vectors[:,-1] = task.dictionary.eos()
seq_list = ['NNAUGC','AAAAAA']
# Iterate through UTRs and one-hot encode
for i,seq in enumerate(seq_list): 
    seq_index = []
    for j in range(len(seq)):
        if seq[j] == 'N':
            seq_index.append(task.dictionary.pad())
        else:
            seq_index.append(task.dictionary.index(seq[j]))    
    vectors[i,1:-1] = np.array(seq_index)
vectors

array([[0., 1., 1., 4., 6., 5., 7., 2.],
       [0., 4., 4., 4., 4., 4., 4., 2.]])

In [32]:
a = torch.empty((1))
b = torch.empty((10,1))
torch.cat((a,b),dim=0)

AttributeError: module 'torch' has no attribute 'vstack'

In [20]:
# Basic dataset for downstream tasks
class BasicDataset(Dataset):
    def __init__(self, fasta_file_path, label_file_path, task):
        self.rna_seqs, self.label = BasicDataset.load_file(fasta_file_path, label_file_path)
        self.encoded_rna_seqs = BasicDataset.encode(self.rna_seqs, task)
    
    def __getitem__(self,index):
        seq = self.encoded_rna_seqs[index]
        label = self.label
        return seq,label
    
    def __len__(self):
        return len(self.encoded_rna_seqs)
    
    @staticmethod
    def load_file(fasta_file_path, label_file_path):
        # load fasta file
        rna_seqs_list = []
        for record in SeqIO.parse(fasta_file_path, "fasta"):
            rna_seqs_list.append(str(record.seq))
        # load label file
        label = []
        with open(label_file_path) as fa:
            for line in fa:
                label.append(line)
        return rna_seqs_list, label
    
    @staticmethod
    def encode(rna_seqs_list, task):
        prepared_rna_seqs = []
        # add '<s>' token 
        for i in range(len(rna_seqs_list)):
            seq = list(rna_seqs_list[i])
            prepared_rna_seqs.append('<s> ' + ' '.join(seq))
        # encode rna seqs
        return data_utils.collate_tokens([task.source_dictionary.encode_line(seq, add_if_not_exist=False) for seq in prepared_rna_seqs], 
                                          pad_idx=task.source_dictionary.pad(), eos_idx=task.source_dictionary.eos()).long()

In [21]:
# Define models for downstream tasks
class DownstreamModel(nn.Module):
    def __init__(self, model, **kwargs):
        self.encoder = model.encoder
    
    def forward(self,x):
        # input x: encoded rna seqs
        # get embedding from encode(bert)
        _, x = self.encoder(x, segment_labels=None, masked_tokens=None,
                extra_only=False, masked_only=False)
        x = x['inner_states'][-1].transpose(0, 1) # (T, B, C) -> (B, T, C)
        
        # other downstream layer
        return x

In [22]:
# Fine-tune for downstream tasks

In [23]:
import argparse
import logging
from torch.utils.data import DataLoader, random_split
import wandb
from torch import optim
from tqdm import tqdm
from pathlib import Path

ModuleNotFoundError: No module named 'wandb'

In [None]:
def train_net(net,
              device,
              fasta_file_path,
              label_file_path,
              task,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              ):
    # 1. Create dataset
    dataset = BasicDataset(fasta_file_path, label_file_path, task)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='Downstream task 1', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                  val_percent=val_percent, save_checkpoint=save_checkpoint))

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=20)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs+1):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (10 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in net.named_parameters():
                            tag = tag.replace('/', '.')
                            if not torch.isinf(value).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not torch.isinf(value.grad).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(net, val_loader, device)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))
                        experiment.log({
                            'learning rate': optimizer.param_groups[0]['lr'],
                            'validation Dice': val_score,
                            'images': wandb.Image(images[0].cpu()),
                            'masks': {
                                'true': wandb.Image(true_masks[0].float().cpu()),
                                'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                            },
                            'step': global_step,
                            'epoch': epoch,
                            **histograms
                        })

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')

In [2]:
# main (fine-tune)
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--fasta_file_path',type=str, default="downstream_rna_fasta_path",help='File path of downstream rna fasta')
    parser.add_argument('--label_file_path',type=str, default="downstream_rna_label_path",help='File path of downstream rna label')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # load pre-trained model, related task and args
    mlm_pretrained_model = '../../../../rna-emb/RNAcentral--rna_mlm_base-MAXLEN1024-ckpt/checkpoint_best.pt'
    arg_overrides = { "data": '../../../../../../data/RNAcentral_data/rnacentral_data-100_ftvocab_bin/' }

    models, pretrain_args, task = checkpoint_utils.load_model_ensemble_and_task(mlm_pretrained_model.split(os.pathsep), 
                                                                    arg_overrides=arg_overrides)
    model = models[0]
    
    
    net = DownstreamModel(model)

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # data file path
    rna_downstream_task_dir = 'downstream_tasks_rna_fasta_dir_path'
    fasta_file = rna_downstream_task_dir + 'rna_seqs_file1.fasta'
    label_file = rna_downstream_task_dir + 'rna_property.idx'
    try:
        train_net(net=net,
                  device=device,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  learning_rate=args.lr,
                  val_percent=args.val / 100,
                  fasta_file_path = args.fasta_file_path,
                  label_file_path = args.label_file_path,
                  task = task)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        raise




In [None]:
import os, sys, pdb, h5py
import os.path
import numpy as np
import torch
import torch.utils.data



def One_hot_to_index(x):
    X = np.zeros((x.shape[0],x.shape[1]+2))
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            if x[i,j].tolist() == [1,0,0,0]:
                X[i,j+1] = 4
            elif x[i,j].tolist() == [0,1,0,0]:
                X[i,j+1] = 7
            elif x[i,j].tolist() == [0,0,1,0]:
                X[i,j+1] = 5
            elif x[i,j].tolist() == [0,0,0,1]:
                X[i,j+1] = 6
    X[:,0] = 0
    X[:,-1] = 2
    return X    


class SeqicSHAPE(torch.utils.data.Dataset):
    def __init__(self, data_path, is_test=False, is_infer=False, use_structure=True):
        """data loader
        
        Args:
            data_path ([str]): h5 file path
            is_test (bool, optional): testset or not. Defaults to False.
        """
        if is_infer:
            self.dataset = self.__load_infer_data__(data_path, use_structure=use_structure)
            print("infer data: ", self.__len__()," use_structure: ", use_structure)
        else:
            dataset = h5py.File(data_path, 'r')
            X_train = np.array(dataset['X_train']).astype(np.float32)
            Y_train = np.array(dataset['Y_train']).astype(np.int32)
            X_test  = np.array(dataset['X_test']).astype(np.float32)
            Y_test  = np.array(dataset['Y_test']).astype(np.int32)
            if len(Y_train.shape) == 1:
                Y_train = np.expand_dims(Y_train, axis=1)
                Y_test  = np.expand_dims(Y_test, axis=1)
            X_train = np.expand_dims(X_train, axis=3).transpose([0, 3, 2, 1])
            X_test  = np.expand_dims(X_test,  axis=3).transpose([0, 3, 2, 1])

            train = {'inputs': X_train, 'targets': Y_train}
            test  = {'inputs': X_test,  'targets': Y_test}

            labels, nums = np.unique(Y_train,return_counts=True)
            print("train:", labels, nums)
            labels, nums = np.unique(Y_test,return_counts=True)
            print("test:", labels, nums)

            train = self.__prepare_data__(train)
            test  = self.__prepare_data__(test)

            if is_test:
                self.dataset = test
            else:
                self.dataset = train


    def __load_infer_data__(self, data_path, use_structure=True):
        # from prismnet.utils import datautils
        dataset = load_testset_txt(data_path, use_structure=use_structure, seq_length=101)
        return dataset
       
    
    def __prepare_data__(self, data):
        inputs    = data['inputs'][:,:,:,:4]
        structure = data['inputs'][:,:,:,4:]
        structure = np.expand_dims(structure[:,:,:,0], axis=3)
        inputs    = np.concatenate([inputs, structure], axis=3)
        data['inputs']  = inputs
        return data

    def __to_sequence__(self, x):
        x1 = np.zeros_like(x[0,:,:1])
        for i in range(x1.shape[0]):
            # import pdb; pdb.set_trace()
            x1[i] = np.argmax(x[0,i,:4])
            # import pdb; pdb.set_trace()
        return x1

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        x = self.dataset['inputs'][index]
        # x = self.__to_sequence__(x)
        y = self.dataset['targets'][index]
        return x, y


    def __len__(self):
        return len(self.dataset['inputs'])
    
    


def get_pretrain_index_from_one_hot(X_train, model):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # X_res shape like: [12002, 4, 101]
    X_res = X_train[:,:4,:]
    # X_train shape like: [12002, 101, 4]
    X_train = X_train[:,:4,:].transpose(0,2,1)
    len = 101
    num = X_train.shape[0]
    dim = 768
    # 定义
    X_train_embd = torch.empty((num,dim,len))
    # 把One_hot转换成预训练字典
    # X_train_index shape like [12002,103]
    X_train_index = One_hot_to_index(X_train)
    X_train_index = torch.from_numpy(X_train_index).long()
    # 1000个为一组转换
    number_tans = np.ceil(num/500)
    for i in range(int(number_tans)):
        print(i)
        input = X_train_index[i*500:(i+1)*500,:].to(device)
        with torch.no_grad():
            _, embed = model.encoder(input)
        # embed['inner_states'][-1] shape like [T, B, C]
        X_train_embd[i*500:(i+1)*500,:,:] = embed['inner_states'][-1][1:-1,:,:].transpose(0,1).transpose(1,2)
        del embed
        del input
        torch.cuda.empty_cache()
    X_train_embd = X_train_embd.numpy()
    X_train_mix = np.concatenate((X_res,X_train_embd),axis=1)
    print("Data_shape:",X_train_mix.shape)
    return X_train_mix

   
class SeqPretrain(torch.utils.data.Dataset):
    def __init__(self, data_path, model, is_test=False):
        """data loader
        
        Args:
            data_path ([str]): h5 file path
            is_test (bool, optional): testset or not. Defaults to False.
        """

        dataset = h5py.File(data_path, 'r')
        X_train = np.array(dataset['X_train']).astype(np.float32)
        Y_train = np.array(dataset['Y_train']).astype(np.int32)
        X_test  = np.array(dataset['X_test']).astype(np.float32)
        Y_test  = np.array(dataset['Y_test']).astype(np.int32)
        if len(Y_train.shape) == 1:
            Y_train = np.expand_dims(Y_train, axis=1)
            Y_test  = np.expand_dims(Y_test, axis=1)
        X_train = get_pretrain_index_from_one_hot(X_train, model)
        X_test = get_pretrain_index_from_one_hot(X_test, model)

        X_train = np.expand_dims(X_train, axis=3).transpose([0, 3, 2, 1])
        X_test  = np.expand_dims(X_test,  axis=3).transpose([0, 3, 2, 1])

        train = {'inputs': X_train, 'targets': Y_train}
        test  = {'inputs': X_test,  'targets': Y_test}

        labels, nums = np.unique(Y_train,return_counts=True)
        print("train:", labels, nums)
        labels, nums = np.unique(Y_test,return_counts=True)
        print("test:", labels, nums)

        train = self.__prepare_data__(train)
        test  = self.__prepare_data__(test)

        if is_test:
            self.dataset = test
        else:
            self.dataset = train     
    
    def __prepare_data__(self, data):
        inputs    = data['inputs'][:,:,:,:4]
        structure = data['inputs'][:,:,:,4:]
        structure = np.expand_dims(structure[:,:,:,0], axis=3)
        inputs    = np.concatenate([inputs, structure], axis=3)
        data['inputs']  = inputs
        return data

    def __to_sequence__(self, x):
        x1 = np.zeros_like(x[0,:,:1])
        for i in range(x1.shape[0]):
            # import pdb; pdb.set_trace()
            x1[i] = np.argmax(x[0,i,:4])
            # import pdb; pdb.set_trace()
        return x1

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        x = self.dataset['inputs'][index]
        # x = self.__to_sequence__(x)
        y = self.dataset['targets'][index]
        return x, y


    def __len__(self):
        return len(self.dataset['inputs'])    
    
    

