In [86]:
import tqdm

In [87]:
import torchvision

transform = torchvision.transforms.Compose([
    torchvision.transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [88]:
import torch 

NUM_TIMESTEPS = 1024

model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, 3)

device = torch.device('cuda')
model = model.to(device)

In [89]:
optimizer = torch.optim.Adam(model.parameters())

In [90]:
from xy_dataset import XYDataset
from circuit_dataset import CircuitDataset

circuit_dataset = CircuitDataset('circuit_log', NUM_TIMESTEPS, gamma=1e-2, transform=transform)
xy_dataset = XYDataset('apex_dataset', transform=transform, random_hflip=True)

In [91]:
def get_indices(N_A, N_B):
    if N_A > N_B:
        idx_a = torch.randperm(N_A)
        idx_b = torch.randperm(N_B)
        while len(idx_b) < N_A:
            idx_b = torch.cat([idx_b, torch.randperm(N_B)])
        idx_b = idx_b[0:N_A]
    else:
        idx_a = torch.randperm(N_A)
        idx_b = torch.randperm(N_B)
        while len(idx_a) < N_B:
            idx_a = torch.cat([idx_a, torch.randperm(N_A)])
        idx_a = idx_a[0:N_B]

    return idx_a, idx_b

In [None]:
import torch.nn.functional as F

EPOCHS = 30
BATCH_SIZE = 32

for epoch in range(EPOCHS):
    
    epoch_loss = 0.0
    circuit_idx, xy_idx = get_indices(len(circuit_dataset), len(xy_dataset))
    train_loss = 0.0
    for idx in tqdm.tqdm(range(0, len(circuit_idx), BATCH_SIZE)):
        
        idx = 0   
        batch_circuit_idx = circuit_idx[idx:idx + BATCH_SIZE]
        batch_xy_idx = xy_idx[idx:idx + BATCH_SIZE]

        batch_images = []
        batch_targets = []
        batch_mask = []

        # load xy images / target
        for i in batch_circuit_idx:
            mask = torch.Tensor([0.0, 0.0, 1.0])
            img, t = circuit_dataset[i]
            target = torch.zeros(3)
            target[2] = t
            batch_images.append(img)
            batch_targets.append(target)
            batch_mask.append(mask)

        for i in batch_xy_idx:
            mask = torch.Tensor([1.0, 1.0, 0.0])
            img, t = xy_dataset[i]
            target = torch.zeros(3)
            target[0:2] = t
            batch_images.append(img)
            batch_targets.append(target)
            batch_mask.append(mask)

        images = torch.stack(batch_images).to(device)
        targets = torch.stack(batch_targets).to(device)
        mask = torch.stack(batch_mask).to(device)
        
        optimizer.zero_grad()
        
        output = model(images)
        
        loss = torch.mean(mask * (targets - output)**2)
        
        train_loss += float(loss)
        loss.backward()
        
        optimizer.step()
    
    print('%d, %f' % (epoch, train_loss / len(circuit_idx)))

100%|██████████| 28/28 [00:27<00:00,  1.02it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

0, 0.000449


100%|██████████| 28/28 [00:27<00:00,  1.01it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

1, 0.000479


100%|██████████| 28/28 [00:27<00:00,  1.02it/s]
  0%|          | 0/28 [00:00<?, ?it/s]

2, 0.000333


100%|██████████| 28/28 [00:27<00:00,  1.03s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

3, 0.000279


100%|██████████| 28/28 [00:29<00:00,  1.04s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

4, 0.000239


100%|██████████| 28/28 [00:29<00:00,  1.05s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

5, 0.000153


100%|██████████| 28/28 [00:29<00:00,  1.06s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

6, 0.000169


100%|██████████| 28/28 [00:29<00:00,  1.01s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

7, 0.000144


100%|██████████| 28/28 [00:29<00:00,  1.05s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

8, 0.000213


100%|██████████| 28/28 [00:29<00:00,  1.05s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

9, 0.000105


100%|██████████| 28/28 [00:29<00:00,  1.04s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

10, 0.000094


100%|██████████| 28/28 [00:29<00:00,  1.06s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

11, 0.000132


100%|██████████| 28/28 [00:29<00:00,  1.05s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

12, 0.000111


100%|██████████| 28/28 [00:29<00:00,  1.05s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

13, 0.000115


100%|██████████| 28/28 [00:29<00:00,  1.04s/it]
  0%|          | 0/28 [00:00<?, ?it/s]

14, 0.000134


 61%|██████    | 17/28 [00:17<00:11,  1.04s/it]