In [12]:
from utils import *
from network import *

In [13]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [14]:
# Pretrained weights, handed to the VggVox constructor in the next cell.
# NOTE(review): the source and format of these weights are defined by
# load_pretrained_weights() elsewhere in the project — confirm before changing.
weights = load_pretrained_weights()

In [15]:
# Build the network from the pretrained weights and place it on the selected device.
model = VggVox(weights=weights).to(device)

In [16]:
# Loss for the paired outputs of the model; kept on the same device as the model.
criterion = ContrastiveLoss().to(device)

In [17]:
# Per-batch loss history, appended to by the training loop and plotted afterwards.
loss_list = []
# Best epoch loss seen so far. Starts at +inf so the first epoch always counts as
# an improvement. A plain float32 tensor suffices: the autograd Variable wrapper
# is deprecated and returns an ordinary tensor anyway.
best_loss = torch.tensor(np.inf).float()

In [18]:
# Training hyperparameters, read by the optimizer, DataLoader and training loop cells.
LEARNING_RATE = 0.001  # step size passed to Adam
N_EPOCHS = 15          # number of full passes over the training pairs
BATCH_SIZE = 64        # pairs per DataLoader batch

In [19]:
# Adam over every model parameter, using the learning rate from the config cell.
# Must be created after the model has been moved to `device` (done in the model cell).
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [9]:
# model, _, optimizer = load_saved_model("checkpoint_20181211-030043_0.014894404448568821.pth.tar", test=False)

In [20]:
# Build the pair dataset and a shuffling loader that feeds the training loop.
voxceleb_dataset = VoxCelebDataset(PAIRS_FILE)
train_dataloader = DataLoader(voxceleb_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=4)
# Ask the loader for its length so a trailing partial batch is counted too;
# truncating len(dataset) / BATCH_SIZE would silently leave that batch out and
# make progress messages overshoot the reported total.
n_batches = len(train_dataloader)

print("training unique users", len(voxceleb_dataset.training_users))
print("training samples", len(voxceleb_dataset))
print("batches", n_batches)

training unique users 80
training samples 12800
batches 200


In [11]:
# Train the model on spectrogram pairs with the contrastive criterion and save a
# checkpoint every time the epoch loss improves on the best value seen so far.
#
# Reads:  model, criterion, optimizer, train_dataloader, voxceleb_dataset,
#         device, N_EPOCHS, best_loss
# Writes: loss_list (one entry per batch), best_loss, is_best, checkpoints on disk

# Number of batches the loader actually yields (includes a trailing partial batch).
total_batches = len(train_dataloader)
# Report progress about 20 times per epoch. Clamp to 1 so loaders with fewer than
# 20 batches do not end up evaluating `i_batch % 0`.
log_every = max(1, total_batches // 20)

for epoch in range(1, N_EPOCHS + 1):
    # Kept as a 1-element tensor so epoch_loss stays a tensor for save_checkpoint.
    running_loss = torch.zeros(1)

    for i_batch, data in enumerate(train_dataloader, 1):
        # Inputs and labels are plain tensors: only the model parameters need
        # gradients, so nothing here is wrapped in Variable / requires_grad.
        mfcc1 = data['spec1'].float().to(device)
        mfcc2 = data['spec2'].float().to(device)
        label = data['label'].float().to(device)

        optimizer.zero_grad()
        output1, output2 = model(mfcc1, mfcc2)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_list.append(batch_loss)
        running_loss += batch_loss
        if i_batch % log_every == 0:
            print("Epoch {}/{}  Batch {}/{} \nCurrent Batch Loss {}\n".format(epoch, N_EPOCHS, i_batch, total_batches, batch_loss))

    # NOTE(review): this divides the sum of per-batch losses by the number of
    # samples. If ContrastiveLoss already averages over the batch, the reported
    # value is smaller than the mean batch loss by a factor of the batch size.
    # Model selection is unaffected (the divisor is constant) — confirm intent.
    epoch_loss = running_loss / len(voxceleb_dataset)
    print("==> Epoch {}/{} Epoch Loss {}".format(epoch, N_EPOCHS, epoch_loss.item()))

    is_best = epoch_loss < best_loss
    if is_best:
        best_loss = epoch_loss

        save_checkpoint({'epoch': epoch,
                         'state_dict': model.state_dict(),
                         'optim_dict': optimizer.state_dict()},
                        loss=epoch_loss)
    else:
        print("### Epoch Loss did not improve\n")

Epoch 1/10  Batch 20/400 
Current Batch Loss 0.6339170336723328

Epoch 1/10  Batch 40/400 
Current Batch Loss 0.8288595080375671

Epoch 1/10  Batch 60/400 
Current Batch Loss 0.7820997834205627

Epoch 1/10  Batch 80/400 
Current Batch Loss 0.9524711966514587

Epoch 1/10  Batch 100/400 
Current Batch Loss 0.7632725238800049

Epoch 1/10  Batch 120/400 
Current Batch Loss 0.457506388425827

Epoch 1/10  Batch 140/400 
Current Batch Loss 0.8156790137290955

Epoch 1/10  Batch 160/400 
Current Batch Loss 0.49526798725128174

Epoch 1/10  Batch 180/400 
Current Batch Loss 0.8071795701980591

Epoch 1/10  Batch 200/400 
Current Batch Loss 0.5311578512191772

Epoch 1/10  Batch 220/400 
Current Batch Loss 0.8559224009513855

Epoch 1/10  Batch 240/400 
Current Batch Loss 0.46079546213150024

Epoch 1/10  Batch 260/400 
Current Batch Loss 0.8578991293907166

Epoch 1/10  Batch 280/400 
Current Batch Loss 0.520740807056427

Epoch 1/10  Batch 300/400 
Current Batch Loss 0.5934733152389526

Epoch 1/10  Ba

In [None]:
# Plot the per-batch training loss. The first PLOT_SKIP_BATCHES entries are
# dropped only when more than that many were recorded; slicing past the end of
# a shorter history would draw an empty figure.
PLOT_SKIP_BATCHES = 5000
plotted_losses = loss_list[PLOT_SKIP_BATCHES:] if len(loss_list) > PLOT_SKIP_BATCHES else loss_list
plt.plot(plotted_losses)
plt.xlabel("batch")
plt.ylabel("batch loss")
plt.title("Training loss per batch");