In [1]:
from utils.utils import *
from utils.specgan import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
import os
import pickle

# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

# Setup

In [12]:
# Every tunable knob for this run lives in this cell.

# -- device --
# With GPU = False the run is pinned to the CPU; with GPU = True the device
# detected in the imports cell is kept as-is.
GPU = False
if not GPU:
    device = 'cpu'

# -- data --
DATA_PATH = './data'
AUDIO_LENGTH = 16_384     # number_samples per clip; other values tried: 32768, 65536
SAMPLING_RATE = 16_000    # sample_rate handed to the dataset
NORMALIZE_AUDIO = False   # forwarded to the dataset as `std`
CHANNELS = 1

# -- model --
LATENT_NOISE_DIM = 100    # size of the generator's input noise vector
MODEL_CAPACITY = 64       # `d` width argument of both networks
LAMBDA_GP = 10            # LAMBDA passed to wasserstein_loss

# -- training --
TRAIN_DISCRIM = 5         # discriminator updates per generator update
EPOCHS = 30               # originally 500; reduced to speed up experiments
BATCH_SIZE = 26
LR_GEN = 1e-4
LR_DISC = 1e-4            # a larger discriminator lr is an alternative to a high TRAIN_DISCRIM
BETA1 = 0.5               # Adam betas shared by both optimizers
BETA2 = 0.9


# Dataset and Dataloader

# Load every clip into CPU RAM up front (spectrogram=True).
# To keep the data in VRAM instead, pass device=device; the DataLoader in the
# model cell would then presumably need num_workers=0 (TODO confirm).
train_set = AudioDataset_ram(DATA_PATH, sample_rate=SAMPLING_RATE, number_samples=AUDIO_LENGTH,
                             extension='wav', std=NORMALIZE_AUDIO, device='cpu', spectrogram=True)

# Fail here, not deep inside the training loop, if nothing was loaded.
assert len(train_set) > 0, f"no samples loaded from {DATA_PATH!r}"
print(len(train_set))


loading sample 0: 100%|██████████| 3/3 [00:02<00:00,  1.49sample/s]

3





In [13]:
# Show the mel statistics exposed by the dataset: mean first, then std, one per line.
print(train_set.mel_mean, train_set.mel_std, sep="\n")

0.61346169079953
3.9526220991308127


# Loading the model

In [14]:
# Build the data loader, the generator/discriminator pair and their Adam optimizers,
# optionally restoring model weights from a previous run.
train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          # NOTE(review): worker processes are fine while train_set lives in CPU RAM;
                          # if it is moved to the GPU, num_workers > 0 will likely fail -- use 0 then.
                          num_workers=4)
# generator and discriminator
spec_gen = SpecGenerator(d=MODEL_CAPACITY, c=CHANNELS, inplace=True).to(device)
spec_disc = SpecDiscriminator(d=MODEL_CAPACITY, c=CHANNELS, inplace=True).to(device)

# random weight init (overwritten below when resuming)
initialize_weights(spec_gen)
initialize_weights(spec_disc)

spec_gen.train()
spec_disc.train()

# Adam optimizers for both generator and discriminator
optimizer_gen = optim.Adam(spec_gen.parameters(), lr=LR_GEN, betas=(BETA1, BETA2))
optimizer_disc = optim.Adam(spec_disc.parameters(), lr=LR_DISC, betas=(BETA1, BETA2))

# Resume: set both values from a checkpoint file name ./save/specgen/gen_<epoch>_<step>.pt.
# -1 means "train from scratch".
start = -1        # step count of the checkpoint to resume from
epoch_start = -1  # epoch of the checkpoint to resume from
if start > 0:
    # map_location lets weights saved on one device load on another
    # (e.g. GPU-saved checkpoints on a CPU-only machine).
    spec_disc.load_state_dict(torch.load(f'./save/specdisc/spec_{epoch_start}_{start}.pt', map_location=device))
    spec_gen.load_state_dict(torch.load(f'./save/specgen/gen_{epoch_start}_{start}.pt', map_location=device))
    # NOTE(review): only model weights are checkpointed; Adam's moment estimates restart from scratch.

# Training loop

Make sure the following four output folders exist before training; otherwise the first save will crash:

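- `save/spechist`: pickled loss history
- `save/specfake`: generated sample tensors
- `save/specdisc`: discriminator checkpoints
- `save/specgen`: generator checkpoints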

In [15]:
# Training: WGAN-GP style loop with TRAIN_DISCRIM discriminator updates per generator update.
# Loss history, checkpoints and generated samples are written under ./save every
# SAVE_EVERY epochs, after the final epoch, and right before an early stop.
MAX_STEPS = 30000  # stop once more than this many generator steps have been taken
SAVE_EVERY = 5     # checkpoint cadence, in epochs

# Create the output folders instead of crashing on the first save.
for sub_dir in ('spechist', 'specfake', 'specdisc', 'specgen'):
    os.makedirs(os.path.join('./save', sub_dir), exist_ok=True)

# `step` counts completed generator steps, and checkpoints are named with that count,
# so resuming continues from `start` itself (start + 1 drifted by one on every resume).
step = max(start, 0)
# Use a fresh name instead of `epoch_start += 1`, which bumped the resume setting on
# every re-run of this cell and silently skipped epochs.
first_epoch = epoch_start + 1
hist = []  # one [generator loss, discriminator loss] pair per generator step
for epoch in range(first_epoch, EPOCHS):
    with tqdm(train_loader, unit="batch") as tepoch:
        for batch_id, real_audio in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch}")
            real_audio = real_audio.to(device)

            # Train the discriminator TRAIN_DISCRIM times per generator step.
            for _ in range(TRAIN_DISCRIM):
                noise = torch.randn(real_audio.shape[0], LATENT_NOISE_DIM, device=device)
                fake_audio = spec_gen(noise)
                # wasserstein_loss is handed the discriminator and both batches; the separate
                # disc_real / disc_fake forward passes that used to sit here were never used.
                loss_disc = wasserstein_loss(spec_disc, real_audio, fake_audio, device,
                                             LAMBDA=LAMBDA_GP, spec_gan=True)
                spec_disc.zero_grad()
                # Keep the graph: the generator step below backpropagates through the last
                # fake_audio again (presumably wasserstein_loss does not detach it -- confirm).
                loss_disc.backward(retain_graph=True)
                optimizer_disc.step()

            # Train the generator on the fakes from the last discriminator iteration.
            all_wasserstein = spec_disc(fake_audio).reshape(-1)
            loss = -torch.mean(all_wasserstein)
            spec_gen.zero_grad()  # drops any generator grads left over from loss_disc.backward
            loss.backward()
            optimizer_gen.step()
            step += 1

            # Record progress for the saved history and the progress bar.
            hist.append([loss.item(), loss_disc.item()])
            if batch_id % 5 == 0 and batch_id > 0:
                tepoch.set_postfix(gen_loss=loss.item(), disc_loss=loss_disc.item())

    stop_early = step > MAX_STEPS
    # Save on the regular cadence, after the final epoch, and before an early stop, so the
    # latest weights are never lost. Previously, epochs after the last multiple of 5 were
    # discarded, and the step limit was only checked on those epochs.
    if epoch % SAVE_EVERY == 0 or epoch == EPOCHS - 1 or stop_early:
        with open(f'./save/spechist/hist_{step}_{epoch}_{batch_id}.pkl', 'wb') as f:
            pickle.dump(hist, f)
        torch.save(spec_gen.state_dict(), f'./save/specgen/gen_{epoch}_{step}.pt')
        torch.save(spec_disc.state_dict(), f'./save/specdisc/spec_{epoch}_{step}.pt')
        with torch.no_grad():
            fake = spec_gen(noise)
            torch.save(fake, f'./save/specfake/fake_{epoch}_{step}.pt')
    if stop_early:
        break

Epoch 0: 100%|██████████| 1/1 [00:06<00:00,  6.01s/batch]
Epoch 1: 100%|██████████| 1/1 [00:06<00:00,  6.19s/batch]
Epoch 2: 100%|██████████| 1/1 [00:05<00:00,  5.52s/batch]
Epoch 3: 100%|██████████| 1/1 [00:06<00:00,  6.23s/batch]
Epoch 4: 100%|██████████| 1/1 [00:05<00:00,  5.84s/batch]
Epoch 5: 100%|██████████| 1/1 [00:06<00:00,  6.61s/batch]
Epoch 6: 100%|██████████| 1/1 [00:06<00:00,  6.40s/batch]
Epoch 7: 100%|██████████| 1/1 [00:06<00:00,  6.65s/batch]
Epoch 8: 100%|██████████| 1/1 [00:07<00:00,  7.19s/batch]
Epoch 9: 100%|██████████| 1/1 [00:06<00:00,  6.66s/batch]
Epoch 10: 100%|██████████| 1/1 [00:06<00:00,  6.33s/batch]
Epoch 11: 100%|██████████| 1/1 [00:06<00:00,  6.31s/batch]
Epoch 12: 100%|██████████| 1/1 [00:06<00:00,  6.17s/batch]
Epoch 13: 100%|██████████| 1/1 [00:06<00:00,  6.25s/batch]
Epoch 14: 100%|██████████| 1/1 [00:06<00:00,  6.66s/batch]
Epoch 15: 100%|██████████| 1/1 [00:06<00:00,  6.64s/batch]
Epoch 16: 100%|██████████| 1/1 [00:06<00:00,  6.30s/batch]
Epoch 1

In [11]:
# NOTE(review): leftover scratch cell. Its execution count (In [11]) is lower than the
# training cell's (In [15]), so it ran before the cells above it. It only evaluates
# 5 modulo 5; consider deleting it.
5 % 5

0