In [1]:
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset

from unet_dac import UnetDAC

In [16]:
from config import NUM_MICS, ANGLE_RES


L = 96
K = 256
INPUT_LEN = 64
model = UnetDAC(L=L, K=K, M=NUM_MICS)

In [17]:
inputs = torch.Tensor(np.random.random((INPUT_LEN, 2*(NUM_MICS - 1), K, L)))
outputs = torch.Tensor(np.random.randint((INPUT_LEN, NUM_MICS)))
print(f'Input shape: {inputs.shape}')
res = model(inputs)
print(f"res shape: {res.shape}")

Input shape: torch.Size([64, 14, 256, 96])
res shape: torch.Size([64, 13, 256, 96])


In [21]:
NUM_TRAIN_EPOCHS = 100
lr: float = 0.001
epochs: int = 100
early_stopping: int = 3
mininbatch_size: int = 64
data = torch.load('data.pt')

dataset = TensorDataset(data['samples'], data['ref_stft'], data['target'])
train_loader = DataLoader(dataset, batch_size=mininbatch_size, shuffle=True)



dataloader = torch.utils.data.DataLoader(dataset, batch_size=mininbatch_size, shuffle=True, num_workers=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
running_loss = 0.0
for epoch in range(epochs):
    print(f'--- epoch {epoch} ---')
    for i, minibatch in tqdm(enumerate(train_loader)):
        samples, ref_stft, target = minibatch
        # print(f"hello I am a minibatch! my dimensions are:")
        # print(f"samples.shape={samples.shape}\nref_stft.shape={ref_stft.shape}\ntarget.shape={target.shape}")
        # Forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(samples)
        # TODO
        # output_directions = torch.dot(outputs, ref_stft * ref_stft.T)
        # output_angle = torch.argmax(output_directions, axis=1)
        loss = criterion(outputs, target // ANGLE_RES)
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0


--- epoch 0 ---


1it [00:05,  5.18s/it]


--- epoch 1 ---


1it [00:05,  5.29s/it]


--- epoch 2 ---


1it [00:12, 12.97s/it]


--- epoch 3 ---


1it [00:06,  6.41s/it]


--- epoch 4 ---


1it [00:05,  5.26s/it]


--- epoch 5 ---


0it [00:05, ?it/s]


KeyboardInterrupt: 