In [1]:
import os
from os.path import join
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from astropy.io import fits
import pyxis.torch as pxt

from networks import ForkCNN
from train import *
from dataset import FiberDataset
import config

# Root of the kl_nn project tree. Defaults to the original cluster location;
# set the KL_NN_ROOT environment variable to point the notebook at another copy.
kl_nn_root = os.environ.get('KL_NN_ROOT', '/xdisk/timeifler/wxs0703/kl_nn')

# A trailing '' component makes join() keep the trailing slash the original
# directory strings carried.
data_dir = join(kl_nn_root, 'train_data_massive', 'train_database')  # passed to pxt.TorchDataset
fits_dir = join(kl_nn_root, 'train_data', '')
samp_dir = join(kl_nn_root, 'samples', 'samples_massive.csv')  # a CSV file, despite the "_dir" name
fig_dir = join(kl_nn_root, 'figures', '')
model_dir = join(kl_nn_root, 'model', '')

In [2]:
# Single-process process group (rank 0 of world_size 1) on GPU 0, so that
# code requiring an initialised group (e.g. DistributedSampler) can be run
# interactively from the notebook.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(0)
# The default process group can only be initialised once per process;
# the guard makes this cell safe to re-run in a live kernel.
if not torch.distributed.is_initialized():
    init_process_group(backend='nccl', rank=0, world_size=1)

In [3]:
# Open the pyxis database located at data_dir as a torch-style dataset.
# NOTE(review): `ds` is rebound to a FiberDataset in the following cell, so
# when cells run in file order this handle is not what the DataLoader uses.
ds = pxt.TorchDataset(data_dir)

In [4]:
# FiberDataset takes its arguments positionally, in the order in which
# config.data yields its values (presumably a dict, i.e. insertion order —
# verify against config.py and the FiberDataset signature).
data_args = [*config.data.values()]
ds = FiberDataset(*data_args)

In [4]:
# Settings for the interactive smoke test below:
#   save_every - presumably the checkpoint interval in epochs (confirm in train.py)
#   nepochs    - number of epochs
#   batch_size - samples per DataLoader batch
save_every, nepochs, batch_size = 1, 20, 100

In [5]:
# A sampler is supplied explicitly, so the loader's own shuffle flag must
# stay off; index selection is left entirely to the DistributedSampler.
dl = DataLoader(
    dataset=ds,
    sampler=DistributedSampler(ds),
    shuffle=False,
    batch_size=batch_size,
    num_workers=4,
    pin_memory=True,
)



In [6]:
# Build the network and move its parameters to device 0 in one step
# (Module.to returns the module itself).
model = ForkCNN(batch_size).to(0)
model  # bare last expression so the cell still displays the module repr

ForkCNN(
  (cnn_img): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation

In [9]:
# Smoke test plus rough timing: iterate over the first 201 batches of the
# loader, running a single forward pass on batch 0 only, then print the
# elapsed wall-clock time in seconds.
start = time.time()
for i, batch in enumerate(dl):
    if i == 0:
        # unsqueeze(…, 1) inserts a length-1 axis at position 1 before the
        # tensors are cast to float32 and copied to device 0.
        img = torch.unsqueeze(batch['img'], 1).float().to(0)
        spec = torch.unsqueeze(batch['spec'], 1).float().to(0)
        fid = batch['fid_pars'].float().view(-1,8).to(0)
        print(img.shape, spec.shape, fid.shape)
        # Call the module rather than .forward() directly so that any
        # registered forward hooks / wrappers are honoured.
        out = model(img, spec)
    print(f'Batch {i} finished')
    if i==200:
        break
# CUDA work is queued asynchronously; wait for it to finish so the
# measured interval includes the GPU computation.
torch.cuda.synchronize()
t = time.time()-start
print(t)

torch.Size([100, 1, 48, 48]) torch.Size([100, 1, 3, 64]) torch.Size([100, 8])


  return F.conv2d(input, weight, bias, self.stride,


Batch 0 finished
Batch 1 finished
Batch 2 finished
Batch 3 finished
Batch 4 finished
Batch 5 finished
Batch 6 finished
Batch 7 finished
Batch 8 finished
Batch 9 finished
Batch 10 finished
Batch 11 finished
Batch 12 finished
Batch 13 finished
Batch 14 finished
Batch 15 finished
Batch 16 finished
Batch 17 finished
Batch 18 finished
Batch 19 finished
Batch 20 finished
Batch 21 finished
Batch 22 finished
Batch 23 finished
Batch 24 finished
Batch 25 finished
Batch 26 finished
Batch 27 finished
Batch 28 finished
Batch 29 finished
Batch 30 finished
Batch 31 finished
Batch 32 finished
Batch 33 finished
Batch 34 finished
Batch 35 finished
Batch 36 finished
Batch 37 finished
Batch 38 finished
Batch 39 finished
Batch 40 finished
Batch 41 finished
Batch 42 finished
Batch 43 finished
Batch 44 finished
Batch 45 finished
Batch 46 finished
Batch 47 finished
Batch 48 finished
Batch 49 finished
Batch 50 finished
Batch 51 finished
Batch 52 finished
Batch 53 finished
Batch 54 finished
Batch 55 finished
Ba

In [26]:
# Tear down the default process group. Guarded so that running this cell
# when no group is initialised (e.g. twice in a row, or on a fresh kernel)
# is a no-op instead of an error.
if torch.distributed.is_initialized():
    destroy_process_group()

In [2]:
# Settings for the multi-process run; these rebind save_every, nepochs and
# batch_size from the interactive cells above.
save_every = 1
# One worker process per visible CUDA device.
world_size = torch.cuda.device_count()
# Remaining hyperparameters are read from the project's config.train mapping.
nepochs, batch_size, nfeatures = (
    config.train[key]
    for key in ('epoch_number', 'batch_size', 'feature_number')
)

In [None]:
# Launch world_size worker processes running train_nn. mp.spawn passes the
# process index as the first argument, followed by the tuple given in `args`.
mp.spawn(
    train_nn,
    args=(world_size, save_every, nepochs, batch_size, nfeatures),
    nprocs=world_size,
)

  return F.conv2d(input, weight, bias, self.stride,
  return F.conv2d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
INFO:Trainer:[TRAIN] Epoch: 1 Loss: 0.40047783193759584 Time: 100:33
INFO:Trainer:[VALID] Epoch: 1 Loss: 0.3603931525195735 Time: 16:12
INFO:Trainer:[TRAIN] Epoch: 2 Loss: 0.36402058187126674 Time: 112:25
INFO:Trainer:[VALID] Epoch: 2 Loss: 0.36181415502591757 Time: 15:28
INFO:Trainer:[TRAIN] Epoch: 3 Loss: 0.360738694012794 Time: 97:40
INFO:Trainer:[VALID] Epoch: 3 Loss: 0.353107712774037 Time: 15:30
INFO:Trainer:[TRAIN] Epoch: 4 Loss: 0.3588963472316023 Time: 96:45
INFO:Trainer:[VALID] Epoch: 4 Loss: 0.3520345464976213 Time: 15:25
INFO:Trainer:[TRAIN] Epoch: 5 Loss: 0.35764340747123236 Time: 96:17
INFO:Trainer:[VALID] Epoch: 5 Loss: 0.35075862735108415 Time: 15:32
