In [None]:
import logging
import os
import torch

# Expose all eight GPUs on this machine to PyTorch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

# Report whether CUDA is usable and, if so, how many GPUs are visible.
if not torch.cuda.is_available():
    logging.warning("cuda is not available,exit!")
else:
    logging.warning("cuda is available!")
    gpu_count = torch.cuda.device_count()
    if gpu_count <= 1:
        logging.warning("it is only one GPU!")
    else:
        logging.warning(f"find{gpu_count}GPUS!")


In [None]:
import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch import optim
from tqdm import tqdm
import logging
from torch.utils.tensorboard import SummaryWriter

from utils import *
from modules import UNet
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import StepLR

# `logging` is Python's built-in logging module; configure the root logger to
# emit "HH:MM:SS - LEVEL: message" lines for INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
)


class Diffusion:
    """DDPM-style Gaussian diffusion over encoded protein tensors.

    A linear beta schedule of ``noise_steps`` values between ``beta_start``
    and ``beta_end`` is built on ``device``; ``alpha = 1 - beta`` and
    ``alpha_hat`` is the cumulative product of ``alpha``. ``sample`` draws
    tensors of shape ``(n, 1, protein_high, protein_width)``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, protein_high=560, protein_width=8, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.protein_high = protein_high
        self.protein_width = protein_width
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return the linear beta schedule as a 1-D tensor of length ``noise_steps``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    @staticmethod
    def _per_sample(values, x):
        """Reshape a ``(batch,)`` tensor so it broadcasts against ``x`` of any rank."""
        return values.reshape(-1, *([1] * (x.dim() - 1)))

    def noise_proteins(self, x, t):
        """Forward-diffuse the encoded proteins ``x`` to timesteps ``t`` in closed form.

        Args:
            x: batch of encoded proteins; the first dimension is the batch.
            t: integer tensor of shape ``(batch,)`` with one timestep per sample.

        Returns:
            ``(x_t, eps)`` where ``eps = randn_like(x)`` and
            ``x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * eps``.
        """
        # Per-sample coefficients reshaped to broadcast over every non-batch dim
        # (works for the 4-D (N, 1, H, W) case as well as other ranks).
        alpha_hat = self._per_sample(self.alpha_hat[t], x)
        sqrt_alpha_hat = torch.sqrt(alpha_hat)
        sqrt_one_minus_alpha_hat = torch.sqrt(1. - alpha_hat)
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` random timesteps uniformly from ``[1, noise_steps - 1]``.

        Timestep 0 is never drawn, matching the reverse loop in ``sample``,
        which stops at ``i == 1``.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        """Generate ``n`` new samples by running the reverse diffusion process.

        Args:
            model: noise-prediction network called as ``model(x, t)``.
            n: number of sequences to generate.

        Returns:
            Tensor of shape ``(n, 1, protein_high, protein_width)`` on ``self.device``.

        The model is put in eval mode for sampling, and its previous
        training/eval mode is restored afterwards, even if sampling raises.
        """
        logging.info(f"Sampling {n} new sequences......")
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, 1, self.protein_high, self.protein_width)).to(self.device)
                steps = range(1, self.noise_steps)
                # reversed(range) has no len(), so pass total explicitly for a proper progress bar.
                for i in tqdm(reversed(steps), total=len(steps), position=0):
                    t = (torch.ones(n) * i).long().to(self.device)
                    predicted_noise = model(x, t)
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]
                    # No fresh noise is added on the final denoising step.
                    if i > 1:
                        noise = torch.randn_like(x)
                    else:
                        noise = torch.zeros_like(x)
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        finally:
            model.train(was_training)
        return x

def _train_step(model, diffusion, proteins, optimizer, mse, device):
    """Run one noise-prediction optimisation step on a batch; return the loss as a float."""
    proteins = proteins.to(device)
    t = diffusion.sample_timesteps(proteins.shape[0]).to(device)
    x_t, noise = diffusion.noise_proteins(proteins, t)
    # .float() converts dtype in place on the tensor's current device;
    # .type(torch.FloatTensor) would copy GPU tensors to the CPU and back.
    x_t = x_t.float()
    noise = noise.float()
    predicted_noise = model(x_t, t).float()

    loss = mse(noise, predicted_noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Single .item() call: one host/device sync per step.
    return loss.item()


def train(args):
    """Train the protein diffusion UNet, sampling and checkpointing every epoch.

    Reads ``args.run_name``, ``args.device``, ``args.lr``, ``args.epochs``,
    ``args.protein_high`` and ``args.protein_width`` (plus whatever
    ``get_data`` reads). Per-step MSE is logged to TensorBoard under
    ``runs/<run_name>``. After each epoch, generated sequences go to
    ``results/<run_name>/<epoch>.fasta`` and the unwrapped model weights go to
    ``models/<run_name>/<epoch>.pt``.

    Raises:
        ValueError: if ``get_data`` yields an empty dataloader.
    """
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("get_data returned an empty dataloader; nothing to train on")

    # Build the network once and wrap it for multi-GPU data parallelism.
    model = nn.DataParallel(UNet().to(device))
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    mse = nn.MSELoss()
    diffusion = Diffusion(protein_high=args.protein_high, protein_width=args.protein_width, device=device)
    writer = SummaryWriter(os.path.join("runs/", args.run_name))
    try:
        for epoch in range(args.epochs):
            logging.info(f"Starting epoch {epoch}:")
            pbar = tqdm(dataloader)
            last_batch_size = 0
            for i, proteins in enumerate(pbar):
                last_batch_size = proteins.shape[0]
                loss_value = _train_step(model, diffusion, proteins, optimizer, mse, device)
                pbar.set_postfix(MSE=loss_value)
                writer.add_scalar("MSE", loss_value, global_step=epoch * n_batches + i)

            # Generate as many sequences as the last training batch contained.
            sampled_proteins = diffusion.sample(model, n=last_batch_size)

            save_sequence(sampled_proteins, os.path.join("results/", args.run_name, f"{epoch}.fasta"))
            # Save the wrapped module's weights so checkpoints load without DataParallel.
            torch.save(model.module.state_dict(), os.path.join("models/", args.run_name, "{0}.pt".format(epoch)))
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()
    print("over!")


def launch():
    """Build the training configuration from the command line and start training.

    Every option defaults to the P450 configuration, so running with no
    arguments behaves as before. ``parse_known_args`` ignores unrelated
    arguments, such as those a Jupyter kernel passes, instead of erroring.
    """
    import argparse

    # allow_abbrev=False stops stray kernel flags (e.g. "--f=kernel.json")
    # from being matched as abbreviations of the options below.
    parser = argparse.ArgumentParser(
        description="Train the P450 protein diffusion model.",
        allow_abbrev=False,
    )
    parser.add_argument("--run_name", default="P450Diffusion")
    parser.add_argument("--epochs", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--protein_high", type=int, default=560)
    parser.add_argument("--protein_width", type=int, default=8)
    parser.add_argument("--dataset_path", default="dataset/P450_All_Plant_Sequences_datasets.fasta")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--lr", type=float, default=2e-3)
    args = parser.parse_known_args()[0]
    train(args)


if __name__ == '__main__':
    launch()
