In [1]:
# You should install minerva before running this script

#!pip install git+https://github.com/discovery-unicamp/Minerva-Dev.git
#!pip install kaleido

In [2]:
from minerva.models.nets.time_series.gans import GAN, TTSGAN_Encoder, TTSGAN_Discriminator, TTSGAN_Generator
from minerva.data.data_module_tools import RandomDataModule
import torch
import lightning as L

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
datamodule = RandomDataModule(
    data_shape = (6,60),
    num_classes = 6,
    num_train_samples = 8,
    num_val_samples = 8,
    num_test_samples= 8,
    batch_size = 16,
)

In [4]:
datamodule.setup("fit")
train_dataloader = datamodule.train_dataloader()

for x, y in train_dataloader:
    print(x.shape, y.shape)
    break

torch.Size([8, 6, 60]) torch.Size([8])


In [5]:
generator = TTSGAN_Generator(seq_len = 60, channels = 6)
discriminator = TTSGAN_Discriminator(seq_len = 60, channels = 6)

model = GAN(generator = generator,
            discriminator = discriminator, 
            loss_gen = torch.nn.MSELoss(),
            loss_dis = torch.nn.MSELoss(),
            )
model

GAN(
  (gen): TTSGAN_Generator(
    (l1): Linear(in_features=100, out_features=600, bias=True)
    (blocks): Gen_TransformerEncoder(
      (0): Gen_TransformerEncoderBlock(
        (0): ResidualAdd(
          (fn): Sequential(
            (0): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
            (1): MultiHeadAttention(
              (keys): Linear(in_features=10, out_features=10, bias=True)
              (queries): Linear(in_features=10, out_features=10, bias=True)
              (values): Linear(in_features=10, out_features=10, bias=True)
              (att_drop): Dropout(p=0.5, inplace=False)
              (projection): Linear(in_features=10, out_features=10, bias=True)
            )
            (2): Dropout(p=0.5, inplace=False)
          )
        )
        (1): ResidualAdd(
          (fn): Sequential(
            (0): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
            (1): FeedForwardBlock(
              (0): Linear(in_features=10, out_features=40, bias=T

In [6]:
'''from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint()
'''

'from lightning.pytorch.callbacks import ModelCheckpoint\n\ncheckpoint_callback = ModelCheckpoint()\n'

In [7]:
from lightning.pytorch.loggers.csv_logs import CSVLogger

logger = CSVLogger(save_dir='./training', name='ttsgan_50000steps_batch16', version = 3)

In [8]:
trainer = L.Trainer(accelerator='cpu', devices=1, callbacks=[], logger=logger, max_steps=50000)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [9]:
trainer.fit(model = model, datamodule = datamodule)

/usr/local/lib/python3.10/dist-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory ./training/ttsgan_50000steps_batch16/version_3 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory ./training/ttsgan_50000steps_batch16/version_3/checkpoints exists and is not empty.

  | Name     | Type                 | Params | Mode 
----------------------------------------------------------
0 | gen      | TTSGAN_Generator     | 65.3 K | train
1 | dis      | TTSGAN_Discriminator | 97.0 K | train
2 | loss_gen | MSELoss              | 0      | train
3 | loss_dis | MSELoss              | 0      | train
----------------------------------------------------------
162 K     Trainable params
0         Non-trainable params
162 K     Total params
0.649     Total estimated model params size (MB)
138       Modules in train

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.


                                                                           

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 28.98it/s, v_num=3]

Epoch 24999: 100%|██████████| 1/1 [00:00<00:00, 18.08it/s, v_num=3]

`Trainer.fit` stopped: `max_steps=50000` reached.


Epoch 24999: 100%|██████████| 1/1 [00:00<00:00, 12.08it/s, v_num=3]


In [10]:
ckp = torch.load(f='/workspaces/container-workspace/tts-gan/Notebooks/training/ttsgan_50000steps_batch16/all_checkpoints/4-epoch=12499-step=50000.ckpt')
generator = TTSGAN_Generator(seq_len = 60, channels = 6)
discriminator = TTSGAN_Discriminator(seq_len = 60, channels = 6)

test_gan = GAN(generator = generator,
            discriminator = discriminator, 
            loss_gen = torch.nn.MSELoss(),
            loss_dis = torch.nn.MSELoss(),
            )
test_gan.load_state_dict(ckp['state_dict'])
print(ckp['state_dict'])

OrderedDict([('gen.pos_embed', tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],


  ckp = torch.load(f='/workspaces/container-workspace/tts-gan/Notebooks/training/ttsgan_50000steps/all_checkpoints/4-epoch=12499-step=50000.ckpt')
