In [2]:
import utilities
import torch
t = utilities.Tokenizer()
import numpy as np
import random

import torch.nn as nn

from collections import namedtuple

In [38]:
class Attention(nn.Module):
    
    def __init__(self, Nx, c1, Ny, c2, causal_mask = False, N_heads = 16, w = 4):
        super().__init__()
        self.ln1 = nn.LayerNorm([c1])
        self.ln2 = nn.LayerNorm([c2])
        self.causal_mask = causal_mask
        if not causal_mask:
            self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads, batch_first=True)
        else:
            self.mask = torch.triu(torch.ones((Nx, Ny)))
            self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads, batch_first=True)
        self.ln3 = nn.LayerNorm([Nx, c1])
        self.l1 = nn.Linear(c1, c1*w)
        self.gelu = nn.GELU()
        self.l2 = nn.Linear(c1*w, c1)

    def forward(self, x, y):
        xn = self.ln1(x)
        yn = self.ln2(y)
        if self.causal_mask:
            attn, _ = self.MAH(xn, yn, yn, attn_mask=self.mask)
        else:
            attn, _ = self.MAH(xn, yn, yn)
        x = x + attn
        x = x + self.l2(self.gelu(self.l1(self.ln3(x))))
        return x

In [39]:
attentive_layer = Attention(10, 64, 16, 16, causal_mask = False, N_heads = 16, w = 4)
attentive_layer.forward(torch.randn(64, 10, 64), torch.randn(64, 16, 16)) 

tensor([[[-1.9449, -0.9045, -2.0473,  ..., -2.6309, -0.0160, -0.0098],
         [ 1.3053, -1.2469,  1.3207,  ...,  0.3813,  0.5213,  0.8853],
         [-0.2047, -0.8329, -1.0516,  ...,  1.3030, -1.0672,  0.2985],
         ...,
         [ 0.2413,  0.8696,  0.2104,  ..., -1.3503, -1.6520, -0.8283],
         [ 0.3270,  1.8363, -0.5287,  ...,  1.3792, -0.9385, -2.0533],
         [ 0.0964, -0.7531,  1.4557,  ...,  1.8700,  1.8954,  0.4743]],

        [[-0.3156, -0.2061, -0.5957,  ...,  0.5068,  1.0280,  0.4847],
         [ 0.1622, -0.3383,  1.6691,  ..., -0.2314,  0.1010, -0.6300],
         [-0.8307,  0.0463,  0.9058,  ...,  0.1783,  0.7428, -0.4628],
         ...,
         [-2.7470, -0.4737, -0.0843,  ..., -0.1056, -0.0181, -3.4866],
         [ 1.0064, -0.2828,  0.3554,  ..., -1.3402,  0.4635,  1.7412],
         [-2.1375,  1.0034,  1.0687,  ...,  1.3476, -0.4074, -1.2882]],

        [[-1.5899,  0.9114, -0.4479,  ...,  3.1802,  0.3627, -0.3579],
         [-0.4938, -0.3386, -1.1516,  ..., -1

In [3]:
class Attention(nn.Module):
    
    def __init__(self, c1, c2, causal_mask = False, N_heads = 16, w = 4, device = torch.device('cuda')):
        super().__init__()
        self.ln1 = nn.LayerNorm([c1])  # [Nx, c1]
        self.ln2 = nn.LayerNorm([c2])  # [Ny, c2]
        self.causal_mask = causal_mask
        self.MAH = nn.MultiheadAttention(embed_dim=c1, kdim=c2, vdim=c2, num_heads=N_heads, batch_first=True)
        self.ln3 = nn.LayerNorm([c1])  # [Nx, c1]
        self.l1 = nn.Linear(c1, c1*w)
        self.gelu = nn.GELU()
        self.l2 = nn.Linear(c1*w, c1)
        self.device = device

    def forward(self, x, y):
        xn = self.ln1(x)
        yn = self.ln2(y)
        if self.causal_mask:
            mask = torch.triu(torch.ones(x.shape[1], y.shape[1], dtype=bool), diagonal=1).to(self.device)
            attn = self.MAH(xn, yn, yn, attn_mask = mask)[0]
        else:
            attn = self.MAH(xn, yn, yn)[0]
        x = x + attn
        x = x + self.l2(self.gelu(self.l1(self.ln3(x))))
        return x

class AttentiveModes(nn.Module):
    def __init__(self, s, c):
        super().__init__()
        self.attention = Attention(c, c, N_heads = 8)
        self.s = s
        self.c = c

    def forward(self, x1, x2, x3):
        g = [x1, x2, x3]
        for m1, m2 in [(0, 1), (2, 0), (1, 2)]:
            a = torch.concatenate((g[m1], torch.transpose(g[m2], 1, 2)), axis=2)
            aflat = a.flatten(0,1)
            c = self.attention(aflat, aflat).reshape_as(a)
            g[m1] = c[:, :, :self.s, :]
            g[m2] = c[:, :, self.s:, :].transpose(1,2)
        return g


class Torso(nn.Module):
    def __init__(self, s, c, i):
        super().__init__()
        self.l1 = nn.Linear(s, c)
        self.attentive_modes = nn.ModuleList([AttentiveModes(s, c) for _ in range(i)])
        self.s = s
        self.c = c
        self.i = i

    def forward(self, x):
        x1 = torch.permute(x, (0, 1, 2, 3))
        x2 = torch.permute(x, (0, 2, 3, 1))
        x3 = torch.permute(x, (0, 3, 1, 2))

        x1 = self.l1(x1)
        x2 = self.l1(x2)
        x3 = self.l1(x3)
        
        for am in self.attentive_modes:
            x1, x2, x3 = am(x1, x2, x3)

        e = torch.reshape(torch.stack([x1, x2, x3], axis=2), (-1, 3 * self.s ** 2, self.c))
        return e

class PolicyHead(nn.Module):
    # Currently assumes our implemented tokenization scheme
    # That is, Nstesp = s and Nlogits = range^3
    def __init__(self, Nsteps, elmnt_range, s, c, Nfeatures = 64, Nheads = 16, Nlayers = 2, device = torch.device('cuda')):
        super().__init__()
        self.Nlayers = Nlayers
        self.Nlogits = (elmnt_range[1]-elmnt_range[0]+1)**3
        self.tokenizer = utilities.Tokenizer(elmnt_range)
        self.Nsteps = Nsteps
        self.Nfeatures = Nfeatures
        self.Nheads = Nheads
        self.device = device

        self.tok_embedding = nn.Embedding(self.Nlogits+1, Nfeatures * Nheads)  #In principle more efficient than forming one-hot vectors and matrix multplying
        self.START_TOK = self.Nlogits
        self.pos_embedding = nn.Embedding(Nsteps, Nfeatures * Nheads)

        # I figure if we are keeping the weights in the LayerNorm, we might as well have
        #   a different one for each layer, but idk really
        self.ln1 = nn.ModuleList([nn.LayerNorm([Nfeatures * Nheads]) for _ in range(Nlayers)])  # [Nsteps, Nfeatures * Nheads]
        self.dropout = nn.Dropout(p=0.1)
        self.self_attention = nn.ModuleList([Attention(Nfeatures * Nheads, Nfeatures * Nheads, causal_mask=True, N_heads=Nheads) for _ in range(Nlayers)])
        self.ln2 = nn.ModuleList([nn.LayerNorm([Nfeatures * Nheads]) for _ in range(Nlayers)])
        self.cross_attention = nn.ModuleList([Attention(Nfeatures * Nheads, c, N_heads=Nheads) for _ in range(Nlayers)])
        
        self.relu = nn.ReLU()
        self.lfinal = nn.Linear(Nfeatures * Nheads, self.Nlogits)

    def predict_logits(self, a, e):   # Assumes a is in tokenized, not one-hot form
        x = self.tok_embedding(a)
        positions = torch.arange(a.shape[1]).repeat((a.shape[0], 1)).to(self.device)
        x = x + self.pos_embedding(positions)
        for i in range(self.Nlayers):
            x = self.ln1[i](x)
            c = self.self_attention[i](x, x)
            c = self.dropout(c)  # Does not run if in evaluation mode
            x = x + c
            x = self.ln2[i](x)
            c = self.cross_attention[i](x, e)
            c = self.dropout(c)
            x = x + c
        o = self.lfinal(self.relu(x))
        return o    # Don't need x bc we are not feeding it to the value head
    
    def forward(self, e, **kwargs):
        if self.training:
            g = kwargs['g']
            #I'm not entirely sure this is right -- need to think on tokens and what the null character is
            #g = torch.cat((torch.tensor([0]), g))
            #Not working at the moment, going to stick with this and not shifting, but maybe there's a shift or something needed?
            # a = nn.functional.one_hot(g, self.Nlogits).float()
            # o, z = self.predict_logits(a, e)
            # return o, z
            a = torch.concatenate((torch.tensor(self.START_TOK).repeat(g.shape[0], 1).to(self.device), g[:, :-1].to(self.device)), axis=1).to(self.device)
            return self.predict_logits(a, e)
        
        else:
            Nsamples = kwargs['Nsamples']
            #a = torch.zeros((Nsamples, self.Nsteps)).long()
            a = [[self.START_TOK] for _ in range(Nsamples)]
            p = torch.ones(Nsamples)
            #z = torch.zeros((Nsamples, self.Nsteps, self.Nfeatures * self.Nheads))
            #Don't care about exporting Z anymore
            for j in range(Nsamples):
                for i in range(self.Nsteps):
                    # encoded = nn.functional.one_hot(a[j, :], self.Nlogits)
                    # o, _ = self.predict_logits(encoded.float(), e)
                    o = self.predict_logits(torch.tensor([a[j]]).to(self.device), e)
                    probs = torch.softmax(o[0, i, :], -1).to('cpu')
                    tok = torch.multinomial(probs, num_samples=1).item()
                    a[j].append(tok)
                    p[j] *= probs[tok]
            
            actions = self.tokenizer.batch_detokenize(torch.tensor(a)[:,1:])
            probs = p/p.sum()

            return namedtuple('Policy', 'actions probs')(actions, probs)
            

class ValueHead(nn.Module):
    def __init__(self, c, d):
        super().__init__()
        self.c = c
        self.d = d
        
        self.l1 = nn.Linear(c, d)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(d, d)
        self.l3 = nn.Linear(d, d)
        self.lf = nn.Linear(d, 1)

    def forward(self, x):
        x = torch.mean(x, axis=1)
        x = self.relu(self.l1(x))
        x = self.relu(self.l2(x))
        x = self.relu(self.l3(x))
        x = self.lf(x)
        return x


## Before implementing heads, read up on Torch head/transformer modules and how they work further. Unclear to me if their transformers do what we want. 
## Also, need to be careful with setting up training vs acting

class AlphaTensor184(nn.Module):
    def __init__(self, s, c, d, elmnt_range, Nsteps, Nsamples, torso_iterations = 8):
        super().__init__()
        self.s = s
        self.c = c
        self.Nlogits = elmnt_range[1]-elmnt_range[0]+1
        self.Nsteps = Nsteps
        self.Nsamples = Nsamples
        
        self.torso = Torso(s, c, torso_iterations)
        self.value_head = ValueHead(c, d) 
        self.policy_head = PolicyHead(Nsteps, elmnt_range, s, c)
    
    def forward(self, x, g=None):
        e = self.torso(x)
        q = self.value_head(e)
        if g == None:  # Inference
            assert(not(self.training))
            policy = self.policy_head(e, Nsamples=self.Nsamples)
            return (q, policy)
        else: # Training
            assert(self.training)
            logits = self.policy_head(e, g=g)
            return (q, logits)

In [34]:
attn_modes = AttentiveModes(4, 64)
print(attn_modes(torch.ones(64,4,4,64), torch.ones(64,4,4,64), torch.ones(64,4,4,64))[0].shape)
torso = Torso(4,64,2)
print(torso(torch.ones(64,4,4,4)).shape)

torch.Size([64, 4, 4, 64])
torch.Size([64, 48, 64])


In [4]:
torch.cuda.empty_cache()

t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
print(f"Total memory: {t}")
print(f"Reserved memory: {r}")
print(f"Allocated memory: {a}")

alphaTensor184 = AlphaTensor184(s = 3, c = 64, d = 64, elmnt_range=(-2, 2), Nsteps=3, Nsamples=32, torso_iterations=6)

t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
print(f"Total memory: {t}")
print(f"Reserved memory: {r}")
print(f"Allocated memory: {a}")


pytorch_total_params = sum(p.numel() for p in alphaTensor184.parameters())

print(pytorch_total_params)



Total memory: 8585216000
Reserved memory: 0
Allocated memory: 0
Total memory: 8585216000
Reserved memory: 0
Allocated memory: 0
47038974


In [5]:
alphaTensor184.to("cuda")

t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
print(f"Total memory: {t}")
print(f"Reserved memory: {r}")
print(f"Allocated memory: {a}")


alphaTensor184.train()

Total memory: 8585216000
Reserved memory: 192937984
Allocated memory: 188172800


AlphaTensor184(
  (torso): Torso(
    (l1): Linear(in_features=3, out_features=64, bias=True)
    (attentive_modes): ModuleList(
      (0-5): 6 x AttentiveModes(
        (attention): Attention(
          (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (MAH): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (l1): Linear(in_features=64, out_features=256, bias=True)
          (gelu): GELU(approximate='none')
          (l2): Linear(in_features=256, out_features=64, bias=True)
        )
      )
    )
  )
  (value_head): ValueHead(
    (l1): Linear(in_features=64, out_features=64, bias=True)
    (relu): ReLU()
    (l2): Linear(in_features=64, out_features=64, bias=True)
    (l3): Linear(in_features=64, out_features=64, bias=True)
    (lf): L

In [6]:
print(alphaTensor184)

from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(alphaTensor184)



AlphaTensor184(
  (torso): Torso(
    (l1): Linear(in_features=3, out_features=64, bias=True)
    (attentive_modes): ModuleList(
      (0-5): 6 x AttentiveModes(
        (attention): Attention(
          (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (MAH): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (l1): Linear(in_features=64, out_features=256, bias=True)
          (gelu): GELU(approximate='none')
          (l2): Linear(in_features=256, out_features=64, bias=True)
        )
      )
    )
  )
  (value_head): ValueHead(
    (l1): Linear(in_features=64, out_features=64, bias=True)
    (relu): ReLU()
    (l2): Linear(in_features=64, out_features=64, bias=True)
    (l3): Linear(in_features=64, out_features=64, bias=True)
    (lf): L

47038974

In [111]:
vh = PolicyHead(10, 64, 3, 64)

vh.train()
vh.eval()

vh.forward(torch.randn(3 * 3 * 3, 64), Nsamples = 32)

tensor(0.0117, grad_fn=<SelectBackward0>)
tensor(0.0025, grad_fn=<SelectBackward0>)
tensor(0.0162, grad_fn=<SelectBackward0>)
tensor(0.0144, grad_fn=<SelectBackward0>)
tensor(0.0142, grad_fn=<SelectBackward0>)
tensor(0.0059, grad_fn=<SelectBackward0>)
tensor(0.0518, grad_fn=<SelectBackward0>)
tensor(0.0255, grad_fn=<SelectBackward0>)
tensor(0.0087, grad_fn=<SelectBackward0>)
tensor(0.0165, grad_fn=<SelectBackward0>)
tensor(0.0085, grad_fn=<SelectBackward0>)
tensor(0.0104, grad_fn=<SelectBackward0>)
tensor(0.0245, grad_fn=<SelectBackward0>)
tensor(0.0692, grad_fn=<SelectBackward0>)
tensor(0.0134, grad_fn=<SelectBackward0>)
tensor(0.0196, grad_fn=<SelectBackward0>)
tensor(0.0106, grad_fn=<SelectBackward0>)
tensor(0.0098, grad_fn=<SelectBackward0>)
tensor(0.0346, grad_fn=<SelectBackward0>)
tensor(0.0364, grad_fn=<SelectBackward0>)
tensor(0.0501, grad_fn=<SelectBackward0>)
tensor(0.0221, grad_fn=<SelectBackward0>)
tensor(0.0286, grad_fn=<SelectBackward0>)
tensor(0.0161, grad_fn=<SelectBack

(tensor([[ 8, 14, 29, 11, 11, 16, 62, 52, 43, 42],
         [26, 27,  1,  4, 33, 30, 63,  7, 48, 48],
         [10, 60, 19, 42, 31, 10, 39, 28,  1, 11],
         [59, 39, 39, 62, 25, 62,  3, 60, 21,  0],
         [40, 22, 45, 39, 13, 48,  1,  4, 48, 30],
         [ 4, 47, 29, 21,  6, 16, 59, 61, 63, 31],
         [ 3, 62,  3,  1, 41, 12, 22,  6, 52, 52],
         [10, 19, 16, 62, 58, 63, 12,  4, 40, 33],
         [47, 39, 57, 30, 10,  0, 31, 52, 44, 61],
         [ 4, 60,  4,  4, 25,  0, 51, 57, 15, 33],
         [63, 48, 10, 51, 18,  1,  4, 21, 15, 47],
         [54, 10, 39, 50,  3, 54, 48, 30, 61, 54],
         [ 6, 38, 37,  4, 36,  3, 48, 47,  4,  3],
         [19, 38,  7, 38,  9, 31, 48, 31, 47, 29],
         [39, 12, 29, 50, 63, 10,  4, 48, 44, 35],
         [45, 16, 48, 57, 45, 57, 41, 38, 19, 30],
         [55, 44, 51, 58,  4,  3,  4, 33, 21, 62],
         [62,  6,  7, 40, 44, 19,  8, 18, 46, 24],
         [47, 40, 42, 22, 14, 10,  3, 21,  5, 10],
         [10, 45, 27,  4, 30, 6

In [2]:
a = torch.tensor([[1, 0, 0],[0,0,0], [-2, -2, -2]])
b = torch.tensor([[-2,1,0],[0,1,1], [-1, 1, -1]])

In [3]:
x = t.tokenize(a)
y = t.tokenize(b)
print(a, b)
print(x, y)
w = t.detokenize(x)
x = t.detokenize(y)

print(w, x)

tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]]) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]])
tensor([63, 62,  0], dtype=torch.int32) tensor([65, 92, 41], dtype=torch.int32)
tensor([[ 1,  0,  0],
        [ 0,  0,  0],
        [-2, -2, -2]], dtype=torch.int32) tensor([[-2,  1,  0],
        [ 0,  1,  1],
        [-1,  1, -1]], dtype=torch.int32)


In [4]:
def generate_sample_r1(S: int, vals: list[int], factor_dist: list[float]):
    nonzero = False
    while not nonzero:
        t = np.random.choice(vals, size=(3, S), p=factor_dist)
        m = np.tensordot(np.tensordot(t[0, :], t[1, :], axes=0), t[2, :], axes=0)
        assert m.shape == (S, S, S)
        nonzero = np.any(m)
    return t, m

In [5]:
S = 3
vals = [-2, -1, 0, 1, 2]
factor_dist = [0.1, 0.2, 0.4, 0.2, 0.1]

t, m = generate_sample_r1(S, vals, factor_dist)

print(f"tensors {t}")
print(f"result {m}")

tensors [[-1  0  0]
 [ 1  1  0]
 [ 1  0  2]]
result [[[-1  0 -2]
  [-1  0 -2]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]]


In [12]:
def main(S: int, r_limit: int, factor_dist: dict, N: int, seed: int = None):
    if seed is not None:
        random.seed(seed)

    low, high = min(factor_dist.keys()), max(factor_dist.keys())
    vals = list(factor_dist.keys())
    dist = [factor_dist[i] for i in factor_dist.keys()]

    
    tokenizer = utilities.Tokenizer(range=(low, high))

    SAR_pairs = []
    
    for i in range(N):
        R = random.randint(1, r_limit)
        T = torch.zeros((S, S, S), dtype=torch.int)
        reward = 0
        for j in range(R):
            sample, m = generate_sample_r1(S, vals, dist)
            T += torch.from_numpy(m)
            tokens = tokenizer.tokenize(torch.from_numpy(sample.T))
            reward += -1
            SAR_pairs.append((T, tokens, reward))

    return SAR_pairs
            

In [15]:
torch.load("data/Sar_pairs_3_1000_123456.pt")

[(tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([61, 62, 82], dtype=torch.int32),
  -1),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 63, 37], dtype=torch.int32),
  -2),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
          [[-2,  1,  1],
           [ 2,  1, -1],
           [ 0,  0,  0]],
  
          [[-1,  2, -1],
           [ 2,  0,  0],
           [-1,  0,  0]]], dtype=torch.int32),
  tensor([82, 62, 63], dtype=torch.int32),
  -3),
 (tensor([[[ 0, -4,  0],
           [-3, -1,  0],
           [ 0, -1, -1]],
  
      

In [14]:
print(main(3, 10, {-2: .1, -1 : .2, 0: 0.4, 1: 0.2, 2: 0.1}, 10, seed=0))

[(tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([109,  62,  28], dtype=torch.int32), -1), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([38, 58, 56], dtype=torch.int32), -2), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   6]],

        [[ -4,   0,   8],
         [  1,   0,  -8],
         [  8,  -1,   6]],

        [[ -2,   2,   1],
         [ -1,   0,   0],
         [ -9,   4,   6]]], dtype=torch.int32), tensor([93, 37, 65], dtype=torch.int32), -3), (tensor([[[ -1,  -1,   6],
         [  1,   0,  -4],
         [-11,  -2,   

In [2]:
t.batch_detokenize(torch.tensor([[0,2,1]]))

tensor([[[-2, -2, -2],
         [ 0, -2, -2],
         [-1, -2, -2]]], dtype=torch.int32)

In [6]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import itertools

class ActionDataset(Dataset):
    def __init__(self, pregen_files, max_pregen, max_selfplay, selfplay_files = None):
        self.max_pregen = max_pregen
        self.max_selfplay = max_selfplay
        l = [torch.load(file) for file in pregen_files]
        self.pregen_actions = list(itertools.chain.from_iterable(l))[:self.max_pregen]
        self.selfplay_actions = []
        if selfplay_files != None:
            l = [torch.load(file) for file in selfplay_files]
            self.selfplay_actions = list(itertools.chain.from_iterable(l))[:self.max_selfplay]

    def __len__(self):
        return len(self.pregen_actions) + len(self.selfplay_actions)

    def __getitem__(self, idx):
        if idx < self.max_pregen:
            return self.pregen_actions[idx]
        else:
            return self.selfplay_actions[idx - self.max_pregen]
        
    def add_selfplay_actions(self, actions):
        self.selfplay_actions = self.selfplay_actions + actions
        if len(self.selfplay_actions) > self.max_selfplay:
            self.selfplay_actions = self.selfplay_actions[len(self.selfplay_actions) - self.max_selfplay:]

dataset = ActionDataset(["data/Sar_pairs_3_1000_123456.pt"], 40000, 10000)

len(dataset.pregen_actions)


40000

In [27]:
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for x in dataloader: print(x)

[tensor([[[[-1,  0,  3],
          [ 1, -5, -1],
          [-2, -2,  3]],

         [[-2, -1,  5],
          [ 2,  1,  0],
          [-1,  1, -2]],

         [[ 0,  2, -1],
          [-6,  1,  3],
          [-4, -2,  0]]],


        [[[-2,  0, -3],
          [ 2,  1,  1],
          [-1, -3,  0]],

         [[-5,  0,  0],
          [ 3,  0,  5],
          [-1,  0, -3]],

         [[-2, -1,  0],
          [ 4,  0,  2],
          [ 4,  2, -1]]],


        [[[ 2,  0, -3],
          [-1, -1, -2],
          [ 0, -1,  0]],

         [[ 1, -1,  1],
          [ 2,  0, -1],
          [ 1,  0, -1]],

         [[ 0,  2,  1],
          [ 0, -2,  2],
          [-2,  1,  1]]],


        [[[-2, -2,  1],
          [ 0, -1, -3],
          [-2, -1,  1]],

         [[ 1,  2, -1],
          [ 0,  2,  1],
          [-3,  4,  3]],

         [[ 0, -2, -2],
          [-1,  4,  0],
          [ 1, -1, -1]]],


        [[[ 0, -4,  0],
          [-3, -1,  0],
          [ 0, -1, -1]],

         [[-2,  1,  1],
     

In [26]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from architecture import *
from tqdm import tqdm

def loss_fn(pred_logits, true_tokens, pred_value, true_value, val_weight=1.0, device = 'cuda'):
    policy_loss = nn.functional.cross_entropy(pred_logits.reshape(-1, pred_logits.shape[-1]), true_tokens.flatten().type(torch.LongTensor).to(device))
    value_loss = ((pred_value - true_value)**2).mean()
    return policy_loss + val_weight*value_loss

def train(model, dataset, epochs, batch_size = 1024, lr=0.001, device = 'cuda'):
    model.train()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        running_loss = 0.0
        for batch in tqdm(dataloader):
            optimizer.zero_grad()
            states, actions, values = batch
            states = states.to(device).float()
            actions = actions.to(device)
            values = values.to(device).float()
            pred_value, pred_logits = model(states, g=actions)
            loss = loss_fn(pred_logits, actions, pred_value, values)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')

In [27]:
alphaTensor184.to('cuda')   
train(alphaTensor184, dataset, 10, batch_size=1024, lr=0.01)

  0%|          | 0/40 [00:00<?, ?it/s]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  2%|▎         | 1/40 [00:05<03:50,  5.90s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  5%|▌         | 2/40 [00:10<03:12,  5.06s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  8%|▊         | 3/40 [00:14<02:55,  4.75s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 10%|█         | 4/40 [00:19<02:46,  4.63s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 12%|█▎        | 5/40 [00:23<02:39,  4.55s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 15%|█▌        | 6/40 [00:28<02:33,  4.51s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 18%|█▊        | 7/40 [00:32<02:28,  4.49s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 20%|██        | 8/40 [00:36<02:23,  4.48s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 22%|██▎       | 9/40 [00:41<02:17,  4.45s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 25%|██▌       | 10/40 [00:45<02:13,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 28%|██▊       | 11/40 [00:50<02:08,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 30%|███       | 12/40 [00:54<02:03,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 32%|███▎      | 13/40 [00:58<01:56,  4.33s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 35%|███▌      | 14/40 [01:02<01:50,  4.26s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 38%|███▊      | 15/40 [01:06<01:44,  4.20s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 40%|████      | 16/40 [01:10<01:39,  4.16s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 42%|████▎     | 17/40 [01:14<01:35,  4.15s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 45%|████▌     | 18/40 [01:19<01:31,  4.17s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 48%|████▊     | 19/40 [01:23<01:27,  4.16s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 50%|█████     | 20/40 [01:27<01:24,  4.24s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 52%|█████▎    | 21/40 [01:32<01:21,  4.30s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 55%|█████▌    | 22/40 [01:36<01:17,  4.33s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 57%|█████▊    | 23/40 [01:40<01:13,  4.31s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 60%|██████    | 24/40 [01:45<01:08,  4.28s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 62%|██████▎   | 25/40 [01:49<01:03,  4.25s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 65%|██████▌   | 26/40 [01:53<00:59,  4.24s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 68%|██████▊   | 27/40 [01:57<00:54,  4.23s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 70%|███████   | 28/40 [02:01<00:50,  4.22s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 72%|███████▎  | 29/40 [02:06<00:46,  4.22s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 75%|███████▌  | 30/40 [02:10<00:42,  4.21s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 78%|███████▊  | 31/40 [02:14<00:37,  4.21s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 80%|████████  | 32/40 [02:19<00:34,  4.31s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 82%|████████▎ | 33/40 [02:23<00:30,  4.39s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 85%|████████▌ | 34/40 [02:28<00:26,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 88%|████████▊ | 35/40 [02:32<00:22,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 90%|█████████ | 36/40 [02:37<00:17,  4.46s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 92%|█████████▎| 37/40 [02:41<00:13,  4.47s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 95%|█████████▌| 38/40 [02:46<00:08,  4.46s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 98%|█████████▊| 39/40 [02:50<00:04,  4.46s/it]

a = torch.Size([64, 3])
e = torch.Size([64, 27, 64])
x = torch.Size([64, 3, 1024])
x = torch.Size([64, 3, 1024])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])


100%|██████████| 40/40 [02:51<00:00,  4.28s/it]


Epoch 1, Loss: 547.7125366210937


  0%|          | 0/40 [00:00<?, ?it/s]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  2%|▎         | 1/40 [00:04<02:51,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  5%|▌         | 2/40 [00:08<02:47,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  8%|▊         | 3/40 [00:13<02:39,  4.32s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 10%|█         | 4/40 [00:17<02:33,  4.27s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 12%|█▎        | 5/40 [00:21<02:28,  4.25s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 15%|█▌        | 6/40 [00:25<02:23,  4.23s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 18%|█▊        | 7/40 [00:29<02:19,  4.22s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 20%|██        | 8/40 [00:34<02:14,  4.22s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 22%|██▎       | 9/40 [00:38<02:10,  4.21s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 25%|██▌       | 10/40 [00:42<02:06,  4.21s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 28%|██▊       | 11/40 [00:46<02:01,  4.21s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 30%|███       | 12/40 [00:50<01:58,  4.22s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 32%|███▎      | 13/40 [00:55<01:57,  4.34s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 35%|███▌      | 14/40 [01:00<01:54,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 38%|███▊      | 15/40 [01:04<01:51,  4.46s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 40%|████      | 16/40 [01:09<01:46,  4.45s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 42%|████▎     | 17/40 [01:13<01:43,  4.48s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 45%|████▌     | 18/40 [01:18<01:40,  4.55s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 48%|████▊     | 19/40 [01:22<01:35,  4.56s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 50%|█████     | 20/40 [01:27<01:30,  4.54s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 52%|█████▎    | 21/40 [01:31<01:25,  4.52s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 55%|█████▌    | 22/40 [01:36<01:21,  4.51s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 57%|█████▊    | 23/40 [01:40<01:16,  4.49s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 60%|██████    | 24/40 [01:45<01:11,  4.48s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 62%|██████▎   | 25/40 [01:49<01:05,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 65%|██████▌   | 26/40 [01:53<01:00,  4.33s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 68%|██████▊   | 27/40 [01:57<00:55,  4.25s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 70%|███████   | 28/40 [02:01<00:50,  4.21s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 72%|███████▎  | 29/40 [02:05<00:45,  4.17s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 75%|███████▌  | 30/40 [02:09<00:41,  4.14s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 78%|███████▊  | 31/40 [02:14<00:37,  4.12s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 80%|████████  | 32/40 [02:18<00:32,  4.12s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 82%|████████▎ | 33/40 [02:22<00:28,  4.12s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 85%|████████▌ | 34/40 [02:26<00:24,  4.14s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 88%|████████▊ | 35/40 [02:31<00:21,  4.30s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 90%|█████████ | 36/40 [02:35<00:17,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 92%|█████████▎| 37/40 [02:40<00:13,  4.46s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 95%|█████████▌| 38/40 [02:44<00:08,  4.45s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 98%|█████████▊| 39/40 [02:49<00:04,  4.43s/it]

a = torch.Size([64, 3])
e = torch.Size([64, 27, 64])
x = torch.Size([64, 3, 1024])
x = torch.Size([64, 3, 1024])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])


100%|██████████| 40/40 [02:49<00:00,  4.24s/it]


Epoch 2, Loss: 366.26778106689454


  0%|          | 0/40 [00:00<?, ?it/s]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  2%|▎         | 1/40 [00:04<02:51,  4.39s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  5%|▌         | 2/40 [00:08<02:46,  4.37s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  8%|▊         | 3/40 [00:13<02:42,  4.38s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 10%|█         | 4/40 [00:17<02:38,  4.39s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 12%|█▎        | 5/40 [00:22<02:35,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 15%|█▌        | 6/40 [00:26<02:30,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 18%|█▊        | 7/40 [00:30<02:26,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 20%|██        | 8/40 [00:35<02:21,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 22%|██▎       | 9/40 [00:39<02:17,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 25%|██▌       | 10/40 [00:44<02:13,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 28%|██▊       | 11/40 [00:48<02:08,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 30%|███       | 12/40 [00:53<02:04,  4.46s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 32%|███▎      | 13/40 [00:57<02:00,  4.47s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 35%|███▌      | 14/40 [01:02<01:55,  4.45s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 38%|███▊      | 15/40 [01:06<01:50,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 40%|████      | 16/40 [01:10<01:46,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 42%|████▎     | 17/40 [01:15<01:41,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 45%|████▌     | 18/40 [01:19<01:37,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 48%|████▊     | 19/40 [01:24<01:32,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 50%|█████     | 20/40 [01:28<01:28,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 52%|█████▎    | 21/40 [01:32<01:24,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 55%|█████▌    | 22/40 [01:37<01:19,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 57%|█████▊    | 23/40 [01:41<01:15,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 60%|██████    | 24/40 [01:46<01:10,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 62%|██████▎   | 25/40 [01:50<01:06,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 65%|██████▌   | 26/40 [01:55<01:01,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 68%|██████▊   | 27/40 [01:59<00:57,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 70%|███████   | 28/40 [02:03<00:53,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 72%|███████▎  | 29/40 [02:08<00:48,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 75%|███████▌  | 30/40 [02:12<00:44,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 78%|███████▊  | 31/40 [02:17<00:39,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 80%|████████  | 32/40 [02:21<00:35,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 82%|████████▎ | 33/40 [02:25<00:30,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 85%|████████▌ | 34/40 [02:30<00:26,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 88%|████████▊ | 35/40 [02:34<00:22,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 90%|█████████ | 36/40 [02:39<00:17,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 92%|█████████▎| 37/40 [02:43<00:13,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 95%|█████████▌| 38/40 [02:48<00:08,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 98%|█████████▊| 39/40 [02:52<00:04,  4.41s/it]

a = torch.Size([64, 3])
e = torch.Size([64, 27, 64])
x = torch.Size([64, 3, 1024])
x = torch.Size([64, 3, 1024])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])


100%|██████████| 40/40 [02:52<00:00,  4.32s/it]


Epoch 3, Loss: 366.06814727783205


  0%|          | 0/40 [00:00<?, ?it/s]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  2%|▎         | 1/40 [00:04<02:54,  4.47s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  5%|▌         | 2/40 [00:08<02:49,  4.47s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  8%|▊         | 3/40 [00:13<02:44,  4.45s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 10%|█         | 4/40 [00:17<02:39,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 12%|█▎        | 5/40 [00:22<02:35,  4.45s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 15%|█▌        | 6/40 [00:26<02:31,  4.44s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 18%|█▊        | 7/40 [00:31<02:26,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 20%|██        | 8/40 [00:35<02:21,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 22%|██▎       | 9/40 [00:39<02:17,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 25%|██▌       | 10/40 [00:44<02:12,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 28%|██▊       | 11/40 [00:48<02:08,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 30%|███       | 12/40 [00:53<02:03,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 32%|███▎      | 13/40 [00:57<01:59,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 35%|███▌      | 14/40 [01:01<01:54,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 38%|███▊      | 15/40 [01:06<01:50,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 40%|████      | 16/40 [01:10<01:45,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 42%|████▎     | 17/40 [01:15<01:41,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 45%|████▌     | 18/40 [01:19<01:37,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 48%|████▊     | 19/40 [01:23<01:32,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 50%|█████     | 20/40 [01:28<01:28,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 52%|█████▎    | 21/40 [01:32<01:23,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 55%|█████▌    | 22/40 [01:37<01:19,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 57%|█████▊    | 23/40 [01:41<01:14,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 60%|██████    | 24/40 [01:46<01:10,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 62%|██████▎   | 25/40 [01:50<01:06,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 65%|██████▌   | 26/40 [01:54<01:01,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 68%|██████▊   | 27/40 [01:59<00:57,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 70%|███████   | 28/40 [02:03<00:52,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 72%|███████▎  | 29/40 [02:08<00:48,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 75%|███████▌  | 30/40 [02:12<00:44,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 78%|███████▊  | 31/40 [02:16<00:39,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 80%|████████  | 32/40 [02:21<00:35,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 82%|████████▎ | 33/40 [02:25<00:30,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 85%|████████▌ | 34/40 [02:30<00:26,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 88%|████████▊ | 35/40 [02:34<00:22,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 90%|█████████ | 36/40 [02:38<00:17,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 92%|█████████▎| 37/40 [02:43<00:13,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 95%|█████████▌| 38/40 [02:47<00:08,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 98%|█████████▊| 39/40 [02:52<00:04,  4.41s/it]

a = torch.Size([64, 3])
e = torch.Size([64, 27, 64])
x = torch.Size([64, 3, 1024])
x = torch.Size([64, 3, 1024])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])


100%|██████████| 40/40 [02:52<00:00,  4.31s/it]


Epoch 4, Loss: 362.3225860595703


  0%|          | 0/40 [00:00<?, ?it/s]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  2%|▎         | 1/40 [00:04<02:52,  4.42s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  5%|▌         | 2/40 [00:08<02:47,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  8%|▊         | 3/40 [00:13<02:43,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 10%|█         | 4/40 [00:17<02:38,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 12%|█▎        | 5/40 [00:22<02:34,  4.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 15%|█▌        | 6/40 [00:26<02:29,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 18%|█▊        | 7/40 [00:30<02:25,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 20%|██        | 8/40 [00:35<02:21,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 22%|██▎       | 9/40 [00:39<02:17,  4.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 25%|██▌       | 10/40 [00:44<02:13,  4.45s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 28%|██▊       | 11/40 [00:48<02:07,  4.41s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 30%|███       | 12/40 [00:53<02:11,  4.69s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 32%|███▎      | 13/40 [00:59<02:15,  5.01s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 35%|███▌      | 14/40 [01:05<02:16,  5.23s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 38%|███▊      | 15/40 [01:11<02:14,  5.40s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 40%|████      | 16/40 [01:16<02:12,  5.51s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 42%|████▎     | 17/40 [01:22<02:09,  5.62s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 45%|████▌     | 18/40 [01:28<02:04,  5.67s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 48%|████▊     | 19/40 [01:35<02:09,  6.18s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 50%|█████     | 20/40 [01:43<02:10,  6.52s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 52%|█████▎    | 21/40 [01:50<02:05,  6.59s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 55%|█████▌    | 22/40 [01:56<02:00,  6.68s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 57%|█████▊    | 23/40 [02:05<02:00,  7.11s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 60%|██████    | 24/40 [02:11<01:51,  6.95s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 62%|██████▎   | 25/40 [02:18<01:44,  6.97s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 65%|██████▌   | 26/40 [02:25<01:35,  6.85s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 68%|██████▊   | 27/40 [02:31<01:27,  6.71s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 70%|███████   | 28/40 [02:37<01:19,  6.60s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 72%|███████▎  | 29/40 [02:46<01:18,  7.15s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 75%|███████▌  | 30/40 [02:52<01:09,  6.90s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 78%|███████▊  | 31/40 [02:58<01:00,  6.71s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 80%|████████  | 32/40 [03:05<00:52,  6.60s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 82%|████████▎ | 33/40 [03:11<00:45,  6.53s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 85%|████████▌ | 34/40 [03:18<00:38,  6.47s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 88%|████████▊ | 35/40 [03:24<00:32,  6.43s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 90%|█████████ | 36/40 [03:30<00:25,  6.39s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 92%|█████████▎| 37/40 [03:36<00:19,  6.35s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 95%|█████████▌| 38/40 [03:43<00:12,  6.32s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 98%|█████████▊| 39/40 [03:50<00:06,  6.73s/it]

a = torch.Size([64, 3])
e = torch.Size([64, 27, 64])
x = torch.Size([64, 3, 1024])
x = torch.Size([64, 3, 1024])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])
c = torch.Size([64, 3, 1024])
e = torch.Size([64, 27, 64])


100%|██████████| 40/40 [03:51<00:00,  5.79s/it]


Epoch 5, Loss: 363.57884521484374


  0%|          | 0/40 [00:00<?, ?it/s]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  2%|▎         | 1/40 [00:06<04:29,  6.91s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  5%|▌         | 2/40 [00:13<04:26,  7.01s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


  8%|▊         | 3/40 [00:21<04:22,  7.09s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 10%|█         | 4/40 [00:28<04:13,  7.03s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 12%|█▎        | 5/40 [00:35<04:06,  7.03s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 15%|█▌        | 6/40 [00:42<03:59,  7.05s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 18%|█▊        | 7/40 [00:49<03:51,  7.02s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 20%|██        | 8/40 [00:56<03:44,  7.03s/it]

a = torch.Size([1024, 3])
e = torch.Size([1024, 27, 64])
x = torch.Size([1024, 3, 1024])
x = torch.Size([1024, 3, 1024])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])
c = torch.Size([1024, 3, 1024])
e = torch.Size([1024, 27, 64])


 20%|██        | 8/40 [01:01<04:07,  7.74s/it]


KeyboardInterrupt: 