In [1]:
import torch
import torch.nn as nn
import os
import logging
import numpy as np
from tqdm.auto import tqdm
from Model import Patchify, ExtraMAEDecoder, ExtraMAEEncoder, ExtraMAE
from util import random_indexes, take_indexes
from einops import repeat, rearrange
import matplotlib.pyplot as plt
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration

from accelerate import DistributedDataParallelKwargs

from diffusion.util import instantiate_from_config
from diffusers.optimization import get_scheduler
from diffusion.util import instantiate_from_config, load_config

In [2]:
DIRS = ['checkpoints', 'logs', 'samples', "final", "model"]
logger = get_logger(__name__, log_level="INFO")

In [3]:
config = {
    "name": "MITBIH-ExtraMAE",
    "exp_dir": "/home/bejar/PycharmProjects/misiones/Series/Models/ExtraMAE/Train",
    "ema":{
        "inv_gamma": 1.0,
        "power": 0.75,
        "max_decay":0.9999
    },
    "model": {
        "in_channels": 1,
        "series_length": 192,
        "mask_percent": 0.75,
        "layers":8,
        "heads":4,
        "embed_dim":32,
        "patch_size":16
    },

    "projectconf": {
        "total_limit": 2
    },
    "accelerator": {
        "gradient_accumulation_steps": 1,
        "mixed_precision": "no",
        "log_with":"wandb"
    },
    "optimizer":{
        "beta1":0.95,
        "beta2":0.999,
        "weight_decay":1e-6,
        "epsilon":1e-08
    },
    "train": {
        "learning_rate": 1e-4,
        "lr_warmup_steps": 100,
        "epochs": 10,
        "checkpoint_freq": 2000,
        "checkpoint_epoch_freq": 2,
        "loss": "L2"
    },
    "samples": {
        "samples_freq": 25,
        "samples_num": 20,
        "samples_gen": 1000
    },

    "dataset": {
        "name": "EBHI",
        "nclasses":5,
        "train": {
            "class":"ai4ha.data.series.MITBIHDataLoader.MITBIHtrain",
            "params":{
                "filename": "/home/bejar/ssdstorage/MITBIH/mitbih_train.csv",
                "n_samples": 2000,
                "resamp": False,
                "oneD": True,
                "fixsize": 192,
                "normalize": False
            }
        },

        "test":{ 
            "class": "ai4ha.data.series.MITBIHDataLoader.MITBIHtest",
            "params":{
                "filename": "/home/bejar/ssdstorage/MITBIH/mitbih_test.csv",
                "n_samples": 100,
                "resamp": False,
                "oneD": True,
                "fixsize": 192,
                "normalize": False
            }
        },
        "dataloader":{
            "batch_size": 512,
            "num_workers": 6,
            "shuffle": True
        }
    },
    "time": 12
}

In [4]:
config = load_config("configs/ExtraMAE-KUHAR-L8-H4-E64-M75-LR4")

In [5]:
BASE_DIR = f"{config['exp_dir']}/logs/{config['name']}"

for dir in DIRS:
    os.makedirs(f"{BASE_DIR}/{dir}", exist_ok=True)   

accparams = config['accelerator']
# accparams["logging_dir"] = f"{BASE_DIR}/logs"
accparams["project_dir"] = BASE_DIR

if 'projectconf' in config:
    accparams['project_config'] = ProjectConfiguration(**config['projectconf'])

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=accparams['gradient_accumulation_steps'] > 1)
accelerator = Accelerator(**accparams, kwargs_handlers=[ddp_kwargs])



In [6]:
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)


02/22/2024 07:34:38 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no



In [7]:
model=  ExtraMAE(
        in_channels=config['model']['in_channels'],
        series_length=config['model']['series_length'],
        mask_percent=config['model']['mask_percent'],
        num_layers=config['model']['layers'],
        num_heads=config['model']['heads'],
        embed_dimension=config['model']['embed_dim'],
        patch_size=config['model']['patch_size']
)

torch.Size([20, 1, 64]) 300 15


In [8]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['train']['learning_rate'] * accelerator.num_processes,
    betas=(config['optimizer']['beta1'], config['optimizer']['beta2']),
    weight_decay=config['optimizer']['weight_decay'],
    eps=config['optimizer']['epsilon'],
)

In [9]:
train_data = instantiate_from_config(config['dataset']['train'])
test_data = instantiate_from_config(config['dataset']['test'])
train_dataloader = torch.utils.data.DataLoader(train_data, **config['dataset']["dataloader"])

(20750, 300)
(20750, 300)
(20750, 300)
(20750, 300)
(20750, 300)
(20750, 300)
X_train shape is (20750, 6, 300)
y_train shape is (20750,)
(20750, 300)
(20750, 300)
(20750, 300)
(20750, 300)
(20750, 300)
(20750, 300)
X_train shape is (20750, 6, 300)
y_train shape is (20750,)


In [10]:
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=config['train']['lr_warmup_steps'] * accparams['gradient_accumulation_steps'],
    num_training_steps=(len(train_dataloader) * config['train']['epochs']),
)

In [11]:
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler)
model.to(accelerator.device);

In [12]:
num_update_steps_per_epoch = len(train_dataloader) // config['accelerator']['gradient_accumulation_steps']
model.train()
for epoch in range(config['train']['epochs']):
    progress_bar = tqdm(total=num_update_steps_per_epoch,
                        disable=not accelerator.is_local_main_process)
    progress_bar.set_description(f"Epoch {epoch}")
    mean_loss = 0
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(model):
            data, label = batch
            data = data.float().to(accelerator.device)
            pred, mask = model(data)
            batch_mask = data * mask
            pred_mask = pred * mask
            loss = torch.nn.functional.mse_loss(pred_mask,
                                                batch_mask,
                                                reduction="none")
            loss = loss.mean()
            mean_loss += loss.item()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_postfix({"loss": loss.item(), "mean_loss": mean_loss/(step+1)})


  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

In [13]:
tdata = train_data[2][0] #test_data[1][0]

In [14]:
tdata.shape

(6, 300)

In [15]:
ttdata = torch.tensor(tdata).unsqueeze(0).to(accelerator.device).float()

In [16]:
ttdata.shape

torch.Size([1, 6, 300])

In [17]:
# a = torch.empty(ttdata.shape[2] // config['model']['patch_size']).uniform_(0, 1)

mask = torch.tensor([0,1,0,1,0,1,0,1,0,1,0,1])
mask

tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

In [18]:
# indexes = torch.nonzero(mask - 1, as_tuple=False)
# indexes_comp = torch.nonzero(mask, as_tuple=False)
# indexes_r = torch.cat((indexes[0], indexes_comp[0]))

In [19]:
# forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes_r],
#                                         axis=-1),
#                                 dtype=torch.long)
# forward_indexes

In [20]:
indexes = torch.nonzero(mask - 1, as_tuple=False)
indexes_comp = torch.nonzero(mask, as_tuple=False)
indexes_comp = torch.cat((indexes, indexes_comp), dim=0)
forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes_comp], axis=-1),
                                  dtype=torch.long)

In [21]:
# forward_indexes.shape, series.shape

In [22]:
series = model.encoder.patchify(ttdata)
series = rearrange(series, 'b c t -> t b c')
series += model.encoder.pos_embedding
series = take_indexes(series, indexes_comp.to(accelerator.device))
series = series[:forward_indexes.shape[0]]
series = model.encoder.tre_layer(series)
features = torch.cat([
    series,
    model.decoder.mask_token.expand(indexes_comp.shape[0] - series.shape[0],
                           series.shape[1], -1)
],
                     dim=0)
features = take_indexes(features, indexes_comp.to(accelerator.device))
features = features + model.decoder.pos_embedding_d
features = model.decoder.trd_layer(features)
features = model.decoder.head(features)
features = model.decoder.patch2img(features)

RuntimeError: The size of tensor a (12) must match the size of tensor b (20) at non-singleton dimension 0

In [None]:
plt.plot(features[0][0].cpu().detach().numpy(), c='r')
plt.plot(ttdata[0][0].cpu().detach().numpy())

In [None]:
res, mask = model(ttdata)

In [None]:
plt.plot(res[0][0].cpu().detach().numpy(), c='r')
plt.plot(ttdata[0][0].cpu().detach().numpy());

In [None]:
plt.plot(mask[0][0].cpu().detach().numpy(), c='r')