In [40]:
# Interactive Hugging Face Hub authentication; in the notebook this renders a
# token-entry widget (see the VBox output below).
import huggingface_hub

notebook_login = huggingface_hub.notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [78]:
# from dataclasses import dataclass

# @dataclass
# class TrainingConfig:
#     # dataloader
#     data_path = "5utr_95.pt"
#     batch = 64
#     train_prop = 0.8
#     valid_prop = 0.2
#     shuffle = True

#     # model
#     input_length = 512
#     in_channels = 5
#     out_channels = 5
#     layers_per_block = 5
#     block_out_channels = [256, 256, 512, 512]  # 4 blocks each side
#     down_block_types = ["DownBlock1D", "DownBlock1D", "AttnDownBlock1D", "DownBlock1D"]
#     up_block_types = ["UpBlock1D", "AttnUpBlock1D", "UpBlock1D", "UpBlock1D"]
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     # train
#     scheduler = "DDPM"  # ['DDPM', 'DDIM']
#     num_train_timesteps = 1000
#     optimizer = "AdamW"  # ['AdamW', ...]
#     lr_warmup_steps = 500
#     epoch = 1000
#     lr = 1e-5

#     # log
#     save_model_epochs = 10
#     save_path = "./save_models"
#     # mixed_precision = 'fp16'
    
#     seed = 2024
    
# config = TrainingConfig()

from dataclasses import dataclass, field

import torch  # needed here: the `device` default below calls torch.cuda.is_available()


@dataclass
class TrainingConfig:
    """Hyper-parameters for the 5'UTR diffusion training run.

    The values here are the reduced "tmp" settings (3 epochs, 50 diffusion steps);
    the commented-out config above holds the full-run values.

    Every attribute carries a type annotation so that ``@dataclass`` turns it into a
    real instance field. Without annotations the values are only class attributes:
    ``vars(config)`` is empty, ``TrainingConfig(epoch=5)`` raises ``TypeError``, and
    anything that serializes the instance's ``__dict__`` sees no hyper-parameters.
    """

    # dataloader
    data_path: str = "5utr_95_tmp.pt"
    batch: int = 64
    train_prop: float = 0.8
    # NOTE(review): informational only -- the data-loading cell takes the validation
    # size as whatever remains after train_prop.
    valid_prop: float = 0.2
    shuffle: bool = True

    # model
    input_length: int = 512
    in_channels: int = 5
    out_channels: int = 5
    layers_per_block: int = 5
    # List defaults must go through default_factory: a bare mutable default on a
    # dataclass field raises ValueError, and a shared list would leak between instances.
    block_out_channels: list = field(default_factory=lambda: [256, 256, 512, 512])  # 4 blocks each side
    down_block_types: list = field(
        default_factory=lambda: ["DownBlock1D", "DownBlock1D", "AttnDownBlock1D", "DownBlock1D"]
    )
    up_block_types: list = field(
        default_factory=lambda: ["UpBlock1D", "AttnUpBlock1D", "UpBlock1D", "UpBlock1D"]
    )
    # Kept as a plain string so the config serializes cleanly; torch accepts "cuda"/"cpu"
    # wherever a device is expected (.to(...), device=...).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # train
    scheduler: str = "DDPM"  # ['DDPM', 'DDIM']
    num_train_timesteps: int = 50
    optimizer: str = "AdamW"  # ['AdamW', ...]
    lr_warmup_steps: int = 10
    epoch: int = 3
    lr: float = 1e-5

    # log
    save_model_epochs: int = 1
    save_path: str = "./save_models"

    seed: int = 2024


config = TrainingConfig()


## Load the dataset and build train/validation loaders

In [79]:
import torch
from torch.utils.data import DataLoader, random_split, TensorDataset

# Seed torch's global RNG so the later model initialisation, DataLoader shuffling and
# the noise drawn in the training loop are reproducible under Restart & Run All.
torch.manual_seed(config.seed)

# weights_only=True: only tensors / plain containers are unpickled (safer loading).
loaded_sequences = torch.load(config.data_path, weights_only=True)
# The UNet is built with config.in_channels input channels, so the data must be
# (N, in_channels, L).
assert loaded_sequences.dim() == 3 and loaded_sequences.shape[1] == config.in_channels, (
    f"expected (N, {config.in_channels}, L) sequences, got {tuple(loaded_sequences.shape)}"
)

dataset = TensorDataset(loaded_sequences)

train_size = int(config.train_prop * len(dataset))
val_size = len(dataset) - train_size  # remainder goes to validation

# A dedicated seeded generator keeps the train/validation split identical on every run,
# independent of how much global RNG state earlier cells consumed.
split_generator = torch.Generator().manual_seed(config.seed)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

train_dataloader = DataLoader(train_dataset, batch_size=config.batch, shuffle=config.shuffle)
val_dataloader = DataLoader(val_dataset, batch_size=config.batch)

In [80]:
# Sanity check: pull only the first batch from each loader and report its tensor shape.
for label, loader in (("Train", train_dataloader), ("Validation", val_dataloader)):
    for first_batch in loader:
        print(f"{label} batch shape:", first_batch[0].shape)
        break

Train batch shape: torch.Size([64, 5, 512])
Validation batch shape: torch.Size([64, 5, 512])


## Create a UNet1DModel

In [81]:
from diffusers import UNet1DModel
import torch.nn as nn


class UNet1DWithSoftmax(nn.Module):
    """Wrapper around diffusers' ``UNet1DModel`` built from the notebook's training config.

    The training loop fits this network to the Gaussian noise added by the scheduler
    (MSE between the prediction and ``noise``). That target is zero-mean and unbounded,
    so by default the raw UNet output is returned. A channel softmax would squash every
    prediction into (0, 1) with the channels summing to 1, and the noise could never be
    matched. The softmax is still available via ``apply_softmax=True``. For example, it
    can be applied to a model used to output channel probabilities directly.

    Args:
        cfg: object providing ``input_length``, ``in_channels``, ``out_channels``,
            ``layers_per_block``, ``block_out_channels``, ``down_block_types`` and
            ``up_block_types``. Defaults to the notebook-level ``config``.
        apply_softmax: if True, apply a softmax over dim=1 (channels) to the UNet output.
    """

    def __init__(self, cfg=None, apply_softmax: bool = False):
        super().__init__()
        cfg = config if cfg is None else cfg
        self.apply_softmax = apply_softmax
        self.unet = UNet1DModel(
            sample_size=cfg.input_length,  # the input length of data
            in_channels=cfg.in_channels,  # the one-hot encoded data
            out_channels=cfg.out_channels,  # same channel count as the input, so the output has the data's shape
            layers_per_block=cfg.layers_per_block,  # how many ResNet layers to use per UNet block
            block_out_channels=cfg.block_out_channels,  # block output channels on each side
            down_block_types=cfg.down_block_types,
            up_block_types=cfg.up_block_types,
        )
        # Parameter-free; kept as a module so existing state_dicts and attribute access still work.
        self.softmax = nn.Softmax(dim=1)  # over channels of a (batch, channels, length) output

    def forward(self, x, timesteps, return_dict=False):
        """Run the UNet on ``x`` at ``timesteps`` and return element 0 of its output (the sample).

        The channel softmax is applied only when the module was built with ``apply_softmax=True``.
        """
        out = self.unet(x, timesteps, return_dict=return_dict)[0]
        return self.softmax(out) if self.apply_softmax else out


model = UNet1DWithSoftmax().to(config.device)

## Create a Scheduler

In [82]:
from diffusers import DDPMScheduler, DDIMScheduler

# Map the configured scheduler name to its diffusers class. An unknown name must fail
# loudly: otherwise `scheduler` would be undefined, or a stale kernel value would be
# silently reused.
SCHEDULER_CLASSES = {"DDPM": DDPMScheduler, "DDIM": DDIMScheduler}
if config.scheduler not in SCHEDULER_CLASSES:
    raise ValueError(
        f"Unknown scheduler {config.scheduler!r}; expected one of {sorted(SCHEDULER_CLASSES)}"
    )

# NOTE(review): clip_sample and the beta schedule are left at diffusers' defaults
# (original question: "clip_sample: True?") -- confirm they suit one-hot data.
scheduler = SCHEDULER_CLASSES[config.scheduler](num_train_timesteps=config.num_train_timesteps)

In [83]:
# # test the loss
# noise = torch.randn(1, 5, 512).to(device)  # (batch_size, in_channels, length)
# timesteps = torch.tensor([500]).to(device)
# with torch.no_grad():
#     noisy_output = model(noise, timesteps).sample

# print(noisy_output.shape)


## Training preparation (optimizer, LR schedule, loss)

In [84]:
import torch.optim as optim
from diffusers.optimization import get_cosine_schedule_with_warmup

# Resolve config.optimizer (e.g. "AdamW") to the torch.optim class of that name, so the
# config setting is honoured instead of being silently ignored.
optimizer_cls = getattr(optim, config.optimizer, None)
if not (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, optim.Optimizer)):
    raise ValueError(
        f"Unknown optimizer {config.optimizer!r}; expected a torch.optim optimizer class name"
    )
optimizer = optimizer_cls(model.parameters(), lr=config.lr)

# Warm-up then cosine decay. The training loop steps this schedule once per batch,
# so the horizon is (batches per epoch) x epochs.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=len(train_dataloader) * config.epoch,
)
criterion = torch.nn.MSELoss()

In [85]:
# def evaluate(config, epoch, pipeline):
    

## Weights & Biases (wandb) setup

In [86]:
# wandb
import wandb


def _config_as_dict(cfg) -> dict:
    """Collect the public, non-callable attributes of ``cfg`` into a plain dict for wandb.

    Class-level attributes are included (via ``dir``), so a config whose values live on
    the class rather than the instance is still logged in full. Values that are not
    basic JSON-like types (e.g. a ``torch.device``) are stringified.

    NOTE(review): wandb appears to read only an object's instance ``__dict__`` when given
    a non-dict config. Passing a plain dict avoids depending on that conversion.
    """
    params = {}
    for name in dir(cfg):
        if name.startswith("_"):
            continue
        value = getattr(cfg, name)
        if callable(value):
            continue
        if not isinstance(value, (str, int, float, bool, list, tuple, dict, type(None))):
            value = str(value)
        params[name] = value
    return params


# initialize wandb
wandb.require("core")
wandb.login()

wandb.init(
    project = "5utr-diffusion-tmp",
    config = _config_as_dict(config),
)



VBox(children=(Label(value='0.043 MB of 0.043 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [87]:
# Log gradients and parameters (log='all') every 10 batches, plus the model graph.
wandb.watch(model, log='all', log_freq=10, log_graph=True)
# NOTE(review): stored as a plain config entry -- confirm this is meant to control system monitoring.
wandb.config.system = {
    "monitor": True,
}

# Per-batch metrics are plotted against "global_step", per-epoch metrics against "epoch".
wandb.define_metric("global_step")  # incremented once per training batch
wandb.define_metric("epoch")
for batch_metric in ("train_loss/batch", "lr/batch"):
    wandb.define_metric(batch_metric, step_metric="global_step")
for epoch_metric in ("train_loss/epoch", "test_loss/epoch"):
    wandb.define_metric(epoch_metric, step_metric="epoch")

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


<wandb.sdk.wandb_metric.Metric at 0x7f9bdbe8e130>

## Train the model

Training loop with per-batch and per-epoch metrics logged to wandb.

In [92]:
import os
import datetime
from tqdm.auto import tqdm


def _timestamp() -> str:
    """Current local time as ``YYYY-mm-dd_HH-MM-SS``, used to prefix checkpoint filenames."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def _save_checkpoint(suffix: str) -> None:
    """Save ``model.state_dict()`` to ``<config.save_path>/<timestamp>_<suffix>``."""
    torch.save(model.state_dict(), os.path.join(config.save_path, f"{_timestamp()}_{suffix}"))


def _mean(values: list) -> float:
    """Arithmetic mean of ``values``; NaN for an empty list instead of ZeroDivisionError."""
    return sum(values) / len(values) if values else float("nan")


def _validation_loss(num_train_timesteps: int) -> float:
    """Mean noise-prediction MSE over ``val_dataloader`` (model in eval mode, no grad).

    Noise and timesteps come from a generator re-seeded with ``config.seed`` on every
    call. Every epoch is therefore scored on the same noisy inputs, and best-model
    selection reflects the model rather than sampling noise.
    """
    generator = torch.Generator(device=config.device).manual_seed(config.seed)
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for val_batch in val_dataloader:
            clean_data = val_batch[0].to(config.device)
            val_noise = torch.randn(
                clean_data.shape, generator=generator, device=config.device, dtype=clean_data.dtype
            )
            val_timesteps = torch.randint(
                0, num_train_timesteps, (clean_data.size(0),), generator=generator, device=config.device
            )
            val_noisy_seq = scheduler.add_noise(clean_data, val_noise, val_timesteps)
            val_noise_pred = model(val_noisy_seq, val_timesteps, return_dict=False)
            batch_losses.append(criterion(val_noise_pred, val_noise).item())
    return _mean(batch_losses)


os.makedirs(config.save_path, exist_ok=True)  # torch.save does not create missing directories
# Read from the scheduler's config: direct attribute access on diffusers schedulers is deprecated.
num_train_timesteps = scheduler.config.num_train_timesteps

global_step = 0  # for wandb log; one step per training batch
best_val_loss = float('inf')

for epoch in tqdm(range(config.epoch), desc="Epochs"):
    model.train()  # switch to train mode (validation switches to eval)
    train_loss_list = []
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}/{config.epoch}", leave=False):
        clean_data = batch[0].to(config.device)

        # sample noise directly on the target device, same shape/dtype as the data
        noise = torch.randn_like(clean_data)

        # sample a random timestep for each sequence
        timesteps = torch.randint(0, num_train_timesteps, (clean_data.size(0),), device=config.device)

        # add noise to the clean sequences
        noisy_seq = scheduler.add_noise(clean_data, noise, timesteps)

        # predict the noise added by scheduler
        noise_pred = model(noisy_seq, timesteps, return_dict=False)
        loss = criterion(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # logs; the key "lr/batch" matches the metric declared in the wandb setup cell
        loss_value = loss.item()
        train_loss_list.append(loss_value)
        global_step += 1
        wandb.log({"train_loss/batch": loss_value, "lr/batch": lr_scheduler.get_last_lr()[0], "global_step": global_step})

    # end of one epoch (all training data used once): evaluate on the validation split
    train_loss = _mean(train_loss_list)
    val_loss = _validation_loss(num_train_timesteps)
    wandb.log({"train_loss/epoch": train_loss, "test_loss/epoch": val_loss, "epoch": epoch})

    # save the best model so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint("5utr-diffusion_best_unet_model.pt")

    # periodic checkpoint after every `save_model_epochs` completed epochs, labelled 1-based
    completed_epochs = epoch + 1
    if completed_epochs % config.save_model_epochs == 0:
        _save_checkpoint(f"5utr-diffusion_unet_epoch_{completed_epochs}.pt")

_save_checkpoint("final_unet_model.pt")
print(">>> Training finished.")

Epochs:   0%|                                                              | 0/3 [00:00<?, ?it/s]
Training Epoch 1/3:   0%|                                                 | 0/13 [00:00<?, ?it/s][A
Training Epoch 1/3:   8%|███▏                                     | 1/13 [00:00<00:01,  9.14it/s][A
Training Epoch 1/3:  15%|██████▎                                  | 2/13 [00:00<00:01,  9.58it/s][A
Training Epoch 1/3:  31%|████████████▌                            | 4/13 [00:00<00:00, 10.06it/s][A
Training Epoch 1/3:  38%|███████████████▊                         | 5/13 [00:00<00:01,  4.83it/s][A
Training Epoch 1/3:  54%|██████████████████████                   | 7/13 [00:01<00:00,  6.47it/s][A
Training Epoch 1/3:  69%|████████████████████████████▍            | 9/13 [00:01<00:00,  7.63it/s][A
Training Epoch 1/3:  85%|█████████████████████████████████▊      | 11/13 [00:01<00:00,  8.46it/s][A
Training Epoch 1/3: 100%|████████████████████████████████████████| 13/13 [00:01<00:00,  9.49it

>>> Training finished.



