In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import os
from tqdm import tqdm
import torch.nn as nn

In [2]:
os.makedirs('new_samples',exist_ok=True)
os.makedirs('recon_samples',exist_ok=True)

In [3]:
class AverageMeter(object):
    """Computes and stores the average and current value.

    Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/train.py
    """
    def __init__(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0.

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0.

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def mean_dim(tensor, dim=None, keepdims=False):
    if dim is None:
        return tensor.mean()
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.mean(dim=d, keepdim=True)
        if not keepdims:
            for i, d in enumerate(dim):
                tensor.squeeze_(d-i)
        return tensor

In [4]:
class act_norm(torch.nn.Module):
  def __init__(self,n_feats):
    super(act_norm,self).__init__()
    self.register_buffer('initialized',torch.zeros(1))
    self.mean=torch.nn.Parameter(torch.zeros(1,n_feats,1,1))
    self.s=torch.nn.Parameter(torch.zeros(1,n_feats,1,1))
    self.n_feats=n_feats

  def init_params(self,x):
    #actnorm is initialized so output of coupling layer had null mean and unit variance
    #scale the data  with the negative of it's mean and inverse of sqrt of variance
    if self.training == 0:
      return None
    with torch.no_grad():
      mean=-mean_dim(x.clone(),dim=[0,2,3],keepdims=True)
      var=mean_dim((x.clone()+mean) ** 2, dim=[0, 2, 3], keepdims=True)
      #mean=-x.mean(dim=[0,2,3],keepdims=True)
      #var=((x+mean)**2).mean(dim=[0,2,3],keepdims=True)
      s=(float(1.)/var.sqrt()+(1e-6)).log()
      self.mean.data.copy_(mean.data)
      self.s.data.copy_(s.data)
      self.initialized+=1.

  def forward(self,x,log_det):
    if self.initialized == 0:
      self.init_params(x)
    x=x+self.mean
    x=x*torch.exp(self.s)
    #h*w*sum(log(s))
    log_det_jac=x.shape[2]*x.shape[3]*torch.sum(self.s)
    if log_det is not None:
      log_det+=log_det_jac
    return (x,log_det)
  
  def backward(self,x,log_det):
    if self.initialized == 0:
      self.init_params(x)
    x=x*torch.exp(self.s*-1)
    #h*w*sum(log(s))
    log_det_jac=x.shape[2]*x.shape[3]*torch.sum(self.s)
    if log_det is not None:
      log_det-=log_det_jac
    x=x-self.mean
    return (x,log_det)

In [5]:
class invertible_conv(torch.nn.Module):
  def __init__(self,n_channels):
    super(invertible_conv,self).__init__()
    self.n_channels=n_channels
    #weight matrix for 1x1 inertible convolution has to be chann x chann and orthogonal(det=0)
    #this is reshaped into chann x chann where each row becomes (channelx1x1==>n*h*w) filter
    #There are n_channel such matrices of this diemsnion
    # the purpose of the invertible convolution is to permute the channels of the image instead of just splitting and \
    # then applying transformation. This can be done by using a fixed permutation matrix(weight of the 1d conv matrix) but
    # GLOW decided to make it learnable instead.
    #refer to notes    
    weight_matrix=np.random.randn(n_channels,n_channels)
    weight_matrix=np.linalg.qr(weight_matrix)[0]
    weight_matrix=weight_matrix.astype("float32")
    weight_matrix=torch.from_numpy(weight_matrix)
    self.filter=torch.nn.Parameter(weight_matrix)
  
  def forward(self,x,log_det):
    #h*w*log|det(W)|
    log_det_jacobian=x.shape[2]*x.shape[3]*torch.slogdet(self.filter)[1]
    #print(log_det_jacobian)
    log_det+=log_det_jacobian
    filter=self.filter.view(self.n_channels,self.n_channels,1,1)
    #ref_weight=ref_weight.view(self.n_channels,self.n_channels,1,1)
    op=torch.nn.functional.conv2d(x,filter)
    return (op,log_det)

  def backward(self,x,log_det):
    log_det_jacobian=x.shape[2]*x.shape[3]*torch.slogdet(self.filter)[1]
    log_det-=log_det_jacobian
    inv_filter=torch.inverse(self.filter.double()).float()
    #ref_weight=torch.inverse(ref_weight.double()).float()
    inv_filter=inv_filter.view(self.n_channels,self.n_channels,1,1)
    #ref_weight=ref_weight.view(self.n_channels,self.n_channels,1,1)

    op=torch.nn.functional.conv2d(x,inv_filter)
    return (op,log_det)

In [6]:
class coupling_layer(torch.nn.Module):
  def __init__(self,in_channel,mid_channel,norm_type="batch_norm"):
    super(coupling_layer,self).__init__()
    self.scale=torch.nn.Parameter(torch.ones(in_channel,1,1))
    out_channel=2*in_channel
    if norm_type == "act_norm":
      self.norm1=act_norm(in_channel)
    else:
      self.norm1=torch.nn.BatchNorm2d(in_channel)
    self.conv1=torch.nn.Conv2d(in_channel,mid_channel,kernel_size=(3,3),padding=1,bias=False)
    nn.init.normal_(self.conv1.weight,0.,0.05)

    if norm_type == "act_norm":
      self.norm2=act_norm(mid_channel)
    else:
      self.norm2=torch.nn.BatchNorm2d(mid_channel)
    self.conv2=torch.nn.Conv2d(mid_channel,mid_channel,kernel_size=(1,1),padding=0,bias=False)
    nn.init.normal_(self.conv2.weight,0.,0.05)

    if norm_type=="act_norm":
      self.norm3=act_norm(mid_channel)
    else:
      self.norm3=torch.nn.BatchNorm2d(mid_channel)
    self.conv3=torch.nn.Conv2d(mid_channel,out_channel,kernel_size=(3,3),padding=1,bias=True)
    nn.init.zeros_(self.conv3.weight)
    nn.init.zeros_(self.conv3.bias)

  def forward(self,x,log_det):
    #split the tensor according to channel dimension
    # transform one half and conctanate with the unchanged half
    x1,x2=x.chunk(2,dim=1)
    
    x3=self.norm1(x2)
    x3=F.relu(x3)
    x3=self.conv1(x3)
    x3=self.norm2(x3)
    x3=F.relu(x3)
    x3=self.conv2(x3)
    x3=self.norm3(x3)
    x3=F.relu(x3)
    x3=self.conv3(x3)

    #for finding the scale and translation parameters alternate channels
    #from the transformation are used instead of just splitting
    #experiment with splitting the channels into 2 halves 
    s=x3[:,0::2,...]
    t=x3[:,1::2,...]
    #s,t=torch.split(x3,int(x3.shape[1]/2),dim=1)

    s=self.scale*torch.tanh(s)
    x1=(x1+t)*torch.exp(s)
    #x1=(x1+t)*torch.exp(s)
    op=torch.cat((x1,x2),dim=1)
    log_jac=torch.sum(s,dim=(1,2,3))
    log_det=log_det + log_jac
    return (op,log_det) 

  def backward(self,x,log_det):
    #split the tensor according to channel dimension
    # transform one half and conctanate with the unchanged half
    x1,x2=x.chunk(2,dim=1)
    
    x3=self.norm1(x2)
    x3=F.relu(x3)
    x3=self.conv1(x3)
    x3=self.norm2(x3)
    x3=F.relu(x3)
    x3=self.conv2(x3)
    x3=self.norm3(x3)
    x3=F.relu(x3)
    x3=self.conv3(x3)

    #for finding the scale and translation parameters alternate channels
    #from the transformation are used instead of just splitting
    #experiment with splitting the channels into 2 halves 
    s=x3[:,0::2,:,:]
    t=x3[:,1::2,:,:]
    #s,t=torch.split(x3,int(x3.shape[1]/2),dim=1)
    
    s=self.scale*torch.tanh(s)
    x1=x1*torch.exp(-s) - t
    #x1=(x1+t)*torch.exp(s)
    op=torch.cat((x1,x2),dim=1)
    log_jac=torch.sum(s,dim=(1,2,3))
    log_det=log_det - log_jac
    return (op,log_det) 

In [7]:
class flow_module(torch.nn.Module):
  ## act_norm ==> invertoble 1x1 conv ==> coupling layer(affine) ##
  def __init__(self,in_channel,mid_channel):
    super(flow_module,self).__init__()

    self.norm1=act_norm(in_channel)
    self.conv1=invertible_conv(in_channel)
    self.coupling1=coupling_layer(in_channel//2,mid_channel)

  def forward(self,x,log_det):
    x,log_det=self.norm1.forward(x,log_det)
    x,log_det=self.conv1.forward(x,log_det)
    x,log_det=self.coupling1.forward(x,log_det)

    return (x,log_det)
  
  def backward(self,x,log_det):
    x,log_det=self.coupling1.backward(x,log_det)
    x,log_det=self.conv1.backward(x,log_det)
    x,log_det=self.norm1.backward(x,log_det)

    return (x,log_det)

In [8]:
def squeeze(x):
  b,c,h,w=x.size()
  x=x.view(b,c,h//2,2,w//2,2)
  x=x.permute(0,1,3,5,2,4).contiguous()
  x=x.view(b,c*2*2,h//2,w//2)
  return (x)
    
def unsqueeze(x):
  # Unsqueeze
  b,c,h,w=x.size()
  x=x.view(b,c//4,2,2,h,w)
  x=x.permute(0,1,4,2,5,3).contiguous()
  x=x.view(b,c//4,h*2,w*2)
  return (x) 

In [9]:
class glow(torch.nn.Module):
  def __init__(self,in_channel,mid_channel,L,K):
    super(glow,self).__init__()
    self.glow_flows_1=torch.nn.ModuleList([flow_module(in_channel,mid_channel) for _ in range(K)])
    if L>1:
      self.glow_flows_2=glow(2*in_channel,mid_channel,L-1,K)
    else:
      self.glow_flows_2=None
  
  def forward(self,x,log_det):
    for block in self.glow_flows_1:
      #print(block)
      x,log_det=block(x,log_det)
    if self.glow_flows_2 is not None:
      x=squeeze(x)
      x,x2=x.chunk(2,dim=1)
      x,log_det=self.glow_flows_2(x,log_det)
      x=torch.cat((x,x2),dim=1)
      x=unsqueeze(x)

    return (x,log_det)
  
  def backward(self,x,log_det):
    if self.glow_flows_2 is not None:
      x=squeeze(x)
      x,x2=x.chunk(2,dim=1)
      x,log_det=self.glow_flows_2(x,log_det)
      x=torch.cat((x,x2),dim=1)
      x=unsqueeze(x)
    
    for block in self.glow_flows_1[::-1]:
      x,log_det=block.backward(x,log_det)

    return (x,log_det)

In [10]:
def preprocess(x):
  noise=torch.distributions.Uniform(0.,1.).sample(x.shape)
  noise=noise.to(device)
  x=(x*255. + noise)/256.
  x*=2.
  x-=1.
  x*=0.9
  x+=1.
  x/=2.
  logit_x=torch.log(x) - torch.log(1.-x)
  pre_logit_scale=torch.tensor(np.log(0.9) - np.log(1.-0.9))
  log_det=F.softplus(logit_x) + F.softplus(-logit_x) -F.softplus(-pre_logit_scale)
  log_det=torch.sum(log_det,dim=(1,2,3))
  x=torch.log(x)-torch.log(1.-x)
  return (x,log_det)

In [11]:
class glow_model(torch.nn.Module):
  def __init__(self,prior_dist,n_channels,L,K):
    super(glow_model,self).__init__()
    self.prior=prior_dist
    self.model=glow(in_channel=4*1,mid_channel=n_channels,L=L,K=K)#for rgb 4*3, for bw 4*1

  def inference(self,x):#X ==> Z
    x,log_det=preprocess(x)
    x=squeeze(x)
    x,log_det=self.model.forward(x,log_det)
    x=unsqueeze(x)
    return (x,log_det)
  
  def sampling(self,x):#Z ==> X
    log_det=torch.zeros(x.shape[0])
    log_det=log_det.to(device)
    z=squeeze(x)
    z,log_det=self.model.backward(z,log_det)
    z=unsqueeze(z)
    return (z)
  
  def sample_images(self,number,channel,height,width):
    z=torch.randn((number,channel,height,width),dtype=torch.float32)
    #z=self.prior.sample((number,channel,height,width))
    z=z.to(device)
    x=self.sampling(z)
    return (x)
  
  def likelihood(self,x):
    #log(p(x))=log(ph(f(x))+log(sii)  
    x_,log_det=self.inference(x)
    prior_ll= -0.5 * (x_**2 + np.log(2*np.pi))
    prior_ll = prior_ll.flatten(1).sum(-1) - np.log(256) * np.prod(x_.size()[1:])
    #log_ll=torch.sum(self.prior.log_prob(x_),dim=[1,2,3])
    #log_ll=log_ll-np.log(256)*np.prod(x_.size()[1:])
    return (prior_ll+log_det) 

  def forward(self,x):
    ll=self.likelihood(x)
    return (ll)

In [12]:
#dataloader
bs=64
transform_train=transforms.Compose([transforms.ToTensor(),transforms.Resize((32,32))])
transform_test=transforms.Compose([transforms.ToTensor(),torchvision.transforms.Resize((32,32))])

trainset=torchvision.datasets.FashionMNIST(root='./data', train=True,download=True,transform=transform_train)
train_loader=torch.utils.data.DataLoader(trainset,batch_size=bs,shuffle=True, num_workers=2)

testset=torchvision.datasets.FashionMNIST(root='./data', train=False,download=True, transform=transform_test)
test_loader=torch.utils.data.DataLoader(testset, batch_size=bs,shuffle=False,num_workers=2)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [13]:
def bits_per_dim(x, nll):
  dim = np.prod(x.size()[1:])
  bpd = nll / (np.log(2) * dim)
  return bpd

In [14]:
mean=torch.tensor((0.))
variance=torch.tensor((1.))
gaussian_dist=torch.distributions.normal.Normal(mean,variance)
Glow_net=glow_model(prior_dist=gaussian_dist,n_channels=512,L=3,K=16)
Glow_net=Glow_net.to(device)
optimizer=torch.optim.Adam(Glow_net.parameters(),lr=1e-3,betas=(0.9,0.999),eps=1e-8)
scheduler=optim.lr_scheduler.LambdaLR(optimizer, lambda s: min(1.,s/500000))

In [15]:
@torch.enable_grad()
def train_model(epoch,Glow_net,train_loader,device,optimizer,scheduler):
  global global_step
  print('\nEpoch: %d'%epoch)
  Glow_net.train()
  loss_meter=AverageMeter()
  with tqdm(total=len(train_loader.dataset)) as progress_bar:
    for image,label in train_loader:
      x=image.to(device)
      optimizer.zero_grad()
      loss=-Glow_net(x).mean()
      loss_meter.update(loss.item(),x.size(0))
      loss.backward()
      optimizer.step()
      scheduler.step()
      progress_bar.set_postfix(nll=loss_meter.avg,bpd=bits_per_dim(x,loss_meter.avg),lr=optimizer.param_groups[0]["lr"])
      progress_bar.update(x.size(0))
      global_step+=x.size(0)

In [16]:
@torch.no_grad()
def test_model(epoch,Glow_net,train_loader,device,optimizer,scheduler):
  #global test_loss
  Glow_net.eval()
  loss_meter=AverageMeter()
  with tqdm(total=len(test_loader.dataset)) as progress_bar:
        for image,label in test_loader:
            x=image.to(device)
            z,det=Glow_net.inference(x)
            prior= -0.5 * (z**2 + np.log(2*np.pi))
            #log_ll=torch.sum(Glow_net.prior.log_prob(z),dim=[1,2,3])
            prior_ll=prior.flatten(1).sum(-1) - np.log(256) * np.prod(z.size()[1:])
            #log_ll=log_ll-np.log(256)*np.prod(z.size()[1:])
            loss=-(prior_ll+det).mean()
            loss_meter.update(loss.item(),x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,bpd=bits_per_dim(x,loss_meter.avg),)
            progress_bar.update(x.size(0))
  images=Glow_net.sample_images(64,1,32,32)
  images1=torch.sigmoid(images)
  #os.makedirs('samples',exist_ok=True)
  images_concat=torchvision.utils.make_grid(images,nrow=int(64 ** 0.5), padding=2, pad_value=255)
  torchvision.utils.save_image(images_concat,'new_samples/epoch_{}.png'.format(epoch))
  images_concat1=torchvision.utils.make_grid(images1,nrow=int(64 ** 0.5), padding=2, pad_value=255)
  torchvision.utils.save_image(images_concat1,'new_samples/epoch1_{}.png'.format(epoch))
  #return (images)

## BEFORE RESTARTING TRAINING ADD MEAN.PT FILE AND RECHECK GENARATED RESULTS

In [None]:
global_step=0
start_epoch=0
num_epochs=100
for epoch in range(start_epoch, start_epoch + num_epochs):
  train_model(epoch,Glow_net,train_loader,device,optimizer,scheduler)
  test_model(epoch,Glow_net,test_loader,device,optimizer,scheduler)

  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 0


100%|██████████| 60000/60000 [07:35<00:00, 131.77it/s, bpd=4.8, lr=1.88e-6, nll=3.41e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.06it/s, bpd=3.74, nll=2.66e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 1


100%|██████████| 60000/60000 [07:35<00:00, 131.64it/s, bpd=3.55, lr=3.75e-6, nll=2.52e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.91it/s, bpd=3.4, nll=2.41e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 2


100%|██████████| 60000/60000 [07:34<00:00, 132.02it/s, bpd=3.34, lr=5.63e-6, nll=2.37e+3]
100%|██████████| 10000/10000 [00:29<00:00, 334.05it/s, bpd=3.27, nll=2.32e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 3


100%|██████████| 60000/60000 [07:35<00:00, 131.83it/s, bpd=3.24, lr=7.5e-6, nll=2.3e+3]
100%|██████████| 10000/10000 [00:29<00:00, 335.36it/s, bpd=3.18, nll=2.26e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 4


100%|██████████| 60000/60000 [07:34<00:00, 132.05it/s, bpd=3.17, lr=9.38e-6, nll=2.25e+3]
100%|██████████| 10000/10000 [00:29<00:00, 336.21it/s, bpd=3.19, nll=2.26e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 5


100%|██████████| 60000/60000 [07:34<00:00, 132.03it/s, bpd=3.12, lr=1.13e-5, nll=2.21e+3]
100%|██████████| 10000/10000 [00:29<00:00, 334.74it/s, bpd=3.1, nll=2.2e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 6


100%|██████████| 60000/60000 [07:36<00:00, 131.51it/s, bpd=3.09, lr=1.31e-5, nll=2.19e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.82it/s, bpd=3.06, nll=2.17e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 7


100%|██████████| 60000/60000 [07:36<00:00, 131.42it/s, bpd=3.05, lr=1.5e-5, nll=2.17e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.90it/s, bpd=3.04, nll=2.16e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 8


100%|██████████| 60000/60000 [07:36<00:00, 131.54it/s, bpd=3.03, lr=1.69e-5, nll=2.15e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.05it/s, bpd=3.02, nll=2.15e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 9


100%|██████████| 60000/60000 [07:36<00:00, 131.31it/s, bpd=3.01, lr=1.88e-5, nll=2.14e+3]
100%|██████████| 10000/10000 [00:29<00:00, 334.70it/s, bpd=3.06, nll=2.17e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 10


100%|██████████| 60000/60000 [07:36<00:00, 131.47it/s, bpd=2.99, lr=2.06e-5, nll=2.12e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.04it/s, bpd=3.05, nll=2.16e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 11


100%|██████████| 60000/60000 [07:38<00:00, 130.74it/s, bpd=2.98, lr=2.25e-5, nll=2.11e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.48it/s, bpd=3, nll=2.13e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 12


100%|██████████| 60000/60000 [07:40<00:00, 130.42it/s, bpd=2.97, lr=2.44e-5, nll=2.11e+3]
100%|██████████| 10000/10000 [00:30<00:00, 328.87it/s, bpd=2.96, nll=2.1e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 13


100%|██████████| 60000/60000 [07:40<00:00, 130.28it/s, bpd=2.95, lr=2.63e-5, nll=2.1e+3]
100%|██████████| 10000/10000 [00:30<00:00, 330.14it/s, bpd=2.97, nll=2.11e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 14


100%|██████████| 60000/60000 [07:40<00:00, 130.29it/s, bpd=2.94, lr=2.81e-5, nll=2.09e+3]
100%|██████████| 10000/10000 [00:30<00:00, 330.21it/s, bpd=2.99, nll=2.12e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 15


100%|██████████| 60000/60000 [07:38<00:00, 130.78it/s, bpd=2.93, lr=3e-5, nll=2.08e+3]
100%|██████████| 10000/10000 [00:30<00:00, 330.05it/s, bpd=3.02, nll=2.14e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 16


100%|██████████| 60000/60000 [07:38<00:00, 130.88it/s, bpd=2.92, lr=3.19e-5, nll=2.07e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.29it/s, bpd=3.01, nll=2.14e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 17


100%|██████████| 60000/60000 [07:36<00:00, 131.33it/s, bpd=2.91, lr=3.38e-5, nll=2.07e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.65it/s, bpd=2.95, nll=2.09e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 18


100%|██████████| 60000/60000 [07:39<00:00, 130.59it/s, bpd=2.91, lr=3.56e-5, nll=2.06e+3]
100%|██████████| 10000/10000 [00:30<00:00, 328.62it/s, bpd=2.95, nll=2.09e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 19


100%|██████████| 60000/60000 [07:37<00:00, 131.08it/s, bpd=2.89, lr=3.75e-5, nll=2.05e+3]
100%|██████████| 10000/10000 [00:30<00:00, 333.17it/s, bpd=3.02, nll=2.14e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 20


100%|██████████| 60000/60000 [07:37<00:00, 131.20it/s, bpd=2.89, lr=3.94e-5, nll=2.05e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.00it/s, bpd=2.92, nll=2.07e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 21


100%|██████████| 60000/60000 [07:38<00:00, 130.89it/s, bpd=2.88, lr=4.13e-5, nll=2.05e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.68it/s, bpd=2.86, nll=2.03e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 22


100%|██████████| 60000/60000 [07:37<00:00, 131.07it/s, bpd=2.88, lr=4.31e-5, nll=2.04e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.86it/s, bpd=2.93, nll=2.08e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 23


100%|██████████| 60000/60000 [07:39<00:00, 130.49it/s, bpd=2.87, lr=4.5e-5, nll=2.04e+3]
100%|██████████| 10000/10000 [00:30<00:00, 329.01it/s, bpd=2.95, nll=2.09e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 24


100%|██████████| 60000/60000 [07:40<00:00, 130.17it/s, bpd=2.86, lr=4.69e-5, nll=2.03e+3]
100%|██████████| 10000/10000 [00:30<00:00, 327.98it/s, bpd=2.95, nll=2.09e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 25


100%|██████████| 60000/60000 [07:41<00:00, 129.97it/s, bpd=2.86, lr=4.88e-5, nll=2.03e+3]
100%|██████████| 10000/10000 [00:30<00:00, 329.14it/s, bpd=2.9, nll=2.05e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 26


100%|██████████| 60000/60000 [07:38<00:00, 130.96it/s, bpd=2.85, lr=5.07e-5, nll=2.02e+3]
100%|██████████| 10000/10000 [00:30<00:00, 330.79it/s, bpd=2.93, nll=2.08e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 27


100%|██████████| 60000/60000 [07:39<00:00, 130.45it/s, bpd=2.85, lr=5.25e-5, nll=2.02e+3]
100%|██████████| 10000/10000 [00:30<00:00, 328.43it/s, bpd=2.89, nll=2.05e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 28


100%|██████████| 60000/60000 [07:41<00:00, 130.12it/s, bpd=2.84, lr=5.44e-5, nll=2.02e+3]
100%|██████████| 10000/10000 [00:30<00:00, 327.05it/s, bpd=2.93, nll=2.08e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 29


100%|██████████| 60000/60000 [07:39<00:00, 130.55it/s, bpd=2.83, lr=5.63e-5, nll=2.01e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.04it/s, bpd=2.96, nll=2.1e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 30


100%|██████████| 60000/60000 [07:37<00:00, 131.27it/s, bpd=2.83, lr=5.82e-5, nll=2.01e+3]
100%|██████████| 10000/10000 [00:30<00:00, 333.16it/s, bpd=3.03, nll=2.15e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 31


100%|██████████| 60000/60000 [07:38<00:00, 130.91it/s, bpd=2.82, lr=6e-5, nll=2e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.16it/s, bpd=2.95, nll=2.09e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 32


100%|██████████| 60000/60000 [07:37<00:00, 131.13it/s, bpd=2.82, lr=6.19e-5, nll=2e+3]
100%|██████████| 10000/10000 [00:29<00:00, 333.92it/s, bpd=2.89, nll=2.05e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 33


100%|██████████| 60000/60000 [07:37<00:00, 131.08it/s, bpd=2.82, lr=6.38e-5, nll=2e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.54it/s, bpd=2.88, nll=2.05e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 34


100%|██████████| 60000/60000 [07:39<00:00, 130.61it/s, bpd=2.81, lr=6.57e-5, nll=1.99e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.22it/s, bpd=2.89, nll=2.05e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 35


100%|██████████| 60000/60000 [07:38<00:00, 130.94it/s, bpd=2.81, lr=6.75e-5, nll=1.99e+3]
100%|██████████| 10000/10000 [00:30<00:00, 330.62it/s, bpd=2.83, nll=2.01e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 36


100%|██████████| 60000/60000 [07:38<00:00, 130.89it/s, bpd=2.81, lr=6.94e-5, nll=1.99e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.35it/s, bpd=2.84, nll=2.02e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 37


100%|██████████| 60000/60000 [07:39<00:00, 130.48it/s, bpd=2.8, lr=7.13e-5, nll=1.99e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.85it/s, bpd=2.85, nll=2.02e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 38


100%|██████████| 60000/60000 [07:36<00:00, 131.38it/s, bpd=2.8, lr=7.32e-5, nll=1.98e+3]
100%|██████████| 10000/10000 [00:29<00:00, 333.35it/s, bpd=2.82, nll=2e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 39


100%|██████████| 60000/60000 [07:37<00:00, 131.07it/s, bpd=2.79, lr=7.5e-5, nll=1.98e+3]
100%|██████████| 10000/10000 [00:30<00:00, 331.57it/s, bpd=2.92, nll=2.07e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 40


100%|██████████| 60000/60000 [07:36<00:00, 131.40it/s, bpd=2.79, lr=7.69e-5, nll=1.98e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.27it/s, bpd=2.83, nll=2.01e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 41


100%|██████████| 60000/60000 [07:36<00:00, 131.56it/s, bpd=2.78, lr=7.88e-5, nll=1.97e+3]
100%|██████████| 10000/10000 [00:29<00:00, 333.72it/s, bpd=2.81, nll=1.99e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 42


100%|██████████| 60000/60000 [07:36<00:00, 131.33it/s, bpd=2.78, lr=8.07e-5, nll=1.97e+3]
100%|██████████| 10000/10000 [00:29<00:00, 333.90it/s, bpd=2.93, nll=2.08e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 43


100%|██████████| 60000/60000 [07:33<00:00, 132.36it/s, bpd=2.77, lr=8.25e-5, nll=1.97e+3]
100%|██████████| 10000/10000 [00:30<00:00, 332.63it/s, bpd=2.83, nll=2.01e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 44


100%|██████████| 60000/60000 [07:33<00:00, 132.30it/s, bpd=2.77, lr=8.44e-5, nll=1.97e+3]
100%|██████████| 10000/10000 [00:29<00:00, 339.75it/s, bpd=2.95, nll=2.1e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 45


100%|██████████| 60000/60000 [07:30<00:00, 133.33it/s, bpd=2.76, lr=8.63e-5, nll=1.96e+3]
100%|██████████| 10000/10000 [00:29<00:00, 337.97it/s, bpd=2.88, nll=2.04e+3]
  0%|          | 0/60000 [00:00<?, ?it/s]


Epoch: 46


100%|██████████| 60000/60000 [07:29<00:00, 133.58it/s, bpd=2.76, lr=8.82e-5, nll=1.96e+3]
 54%|█████▍    | 5376/10000 [00:15<00:13, 345.79it/s, bpd=2.77, nll=1.97e+3]