<a href="https://colab.research.google.com/github/Uzmamushtaque/Rec_transform/blob/main/BERT4NIP_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import itertools
# from box import Box

import warnings
warnings.filterwarnings('ignore')

In [2]:
config = {
    'data_path' : '/content/ratings.csv',
    'max_len' : 64,
    'hidden_units' : 256, # Embedding size
    'num_heads' : 2, # Multi-head layer
    'num_layers': 2, # block (encoder layer)
    'dropout_rate' : 0.1, # dropout
    'lr' : 0.001,
    'batch_size' : 32,
    'num_epochs' : 4,
    'num_workers' : 2,
    'mask_prob' : 0.05, # for cloze task
    'time_seq' : 500, # time limit for one sequence
    'test_size' : 0.33

}

MAX_LEN = 100

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


In [3]:
df = pd.read_csv(config['data_path'])

In [4]:
df.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,1,4.0,964982703
1,1,3,4.0,964981247
2,1,6,4.0,964982224
3,1,47,5.0,964983815
4,1,50,5.0,964982931


In [5]:
class MakeSequenceDataSet():
    """
    SequenceData
    """
    def __init__(self, config):
        #self.df = pd.read_csv(os.path.join(config['data_path'], 'rating.csv'))
        self.df = pd.read_csv(config['data_path'])
        self.item_encoder, self.item_decoder = self.generate_encoder_decoder('movieId')
        self.user_encoder, self.user_decoder = self.generate_encoder_decoder('userId')
        self.num_item, self.num_user = len(self.item_encoder), len(self.user_encoder)

        self.df['item_idx'] = self.df['movieId'].apply(lambda x : self.item_encoder[x] + 1)
        self.df['user_idx'] = self.df['userId'].apply(lambda x : self.user_encoder[x])
        self.df = self.df.sort_values(['user_idx', 'timestamp'])
        self.final_df = pd.DataFrame(self.get_sequence().items(),columns=['user_idx', 'items'])

    def num_item(self):
      return self.num_item

    def generate_encoder_decoder(self, col : str) -> dict:
        """
        encoder, decoder

        Args:
            col (str):  columns
        Returns:
            dict: user encoder, decoder
        """

        encoder = {}
        decoder = {}
        ids = self.df[col].unique()

        for idx, _id in enumerate(ids):
            encoder[_id] = idx
            decoder[idx] = _id

        return encoder, decoder

    def get_sequence(self):
      users = defaultdict(list)
      for user, item in zip(self.df['user_idx'], self.df['item_idx']):
          users[user].append(item)
      return users




In [45]:
s = MakeSequenceDataSet(config)

In [51]:
#s.final_df.iloc[random.randint(0,len(s.final_df)),1]
s.final_df.iloc[1][1]

[233,
 245,
 259,
 257,
 252,
 255,
 220,
 250,
 243,
 251,
 244,
 235,
 236,
 253,
 242,
 240,
 237,
 19,
 238,
 239,
 254,
 241,
 248,
 247,
 249,
 234,
 258,
 256,
 246]

In [8]:
class BERTDataset(Dataset):
  def __init__(self, data, max_len,seq_len,mask_prob) -> None:
     super().__init__()
     self.data = data
     self.max_len = max_len
     self.seq_len = seq_len
     self.mask_prob = mask_prob
     temp = max(self.data.item_encoder.values())
     self.special = {'[CLS]': temp+1,
                     '[SEP]': temp+2,
                     '[MASK]': temp+3,
                     '[PAD]': 0}


  def __len__(self):
    return len(self.data.final_df)

  def __getitem__(self, index):
    # Step 1: get random sentence pair, either negative or positive (saved as is_next_label)
    t1, t2, is_next_label = self.get_sent(index)
    # Step 2: replace random words in sentence with mask / random words
    t1_random, t1_label = self.random_word(t1)
    t2_random, t2_label = self.random_word(t2)
    # Step 3: Adding CLS and SEP tokens to the start and end of sentences
    # Adding PAD token for labels
    t1 = [self.special['[CLS]']] + t1_random + [self.special['[SEP]']]
    t2 = t2_random + [self.special['[SEP]']]
    t1_label = [self.special['[PAD]']] + t1_label + [self.special['[PAD]']]
    t2_label = t2_label + [self.special['[PAD]']]
    # Step 4: combine sentence 1 and 2 as one input
    # adding PAD tokens to make the sentence same length as seq_len
    segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
    bert_input = (t1 + t2)[:self.seq_len]
    bert_label = (t1_label + t2_label)[:self.seq_len]
    padding = [self.special['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
    bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)

    output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

    return {key: torch.tensor(value) for key, value in output.items()}


  def random_word(self,L):
    tokens = L
    output_label = []
    output = []
    # 15% of the tokens would be replaced
    for i, token in enumerate(tokens):
        prob = random.random()

        if prob < 0.15:
            prob /= 0.15

            # 80% chance change token to mask token
            if prob < 0.8:
                output.append(self.special['[MASK]'])

            # 10% chance change token to random token
            elif prob < 0.9:

                output.append(random.randrange(self.data.num_item))

            # 10% chance change token to current token
            else:
                output.append(token)

            output_label.append(token)



    # flattening
    output = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output]))
    output_label = list(itertools.chain(*[[x] if not isinstance(x, list) else x for x in output_label]))
    assert len(output) == len(output_label)
    return output, output_label




  def get_sent(self,index):
    t1 ,t2 = self.get_sentence_pair(self.max_len,index)
    if random.random() > 0.5:
      return t1,t2,1
    else:
      return t1,self.get_random_line()[:self.max_len],0



  def get_sentence_pair(self,seq_len,index):
    l = self.data.final_df.iloc[index][1]
    if len(l) > seq_len:
      return l[:seq_len], l[seq_len:2*seq_len]
    else:
      return l, l

  def get_random_line(self):
    return self.data.final_df.iloc[random.randint(0,len(self.data.final_df)),1]



In [19]:
t= BERTDataset(s, 10, 64,config['mask_prob'])

In [28]:
train_data = BERTDataset(s,20, 64,config['mask_prob'] )
train_loader = DataLoader(
   train_data, batch_size=32, shuffle=True, pin_memory=True)


In [44]:
#sample_data = next(iter(train_loader))
print(train_data[random.randrange(len(train_data))])

{'bert_input': tensor([9724, 9726,   26, 9725, 9726, 9726, 9725,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0]), 'bert_label': tensor([   0,   17,   26,    0, 3679,   26,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0]), 'segment_label': tensor([1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      

In [19]:
s.num_item+3

9727

In [30]:
class PositionalEmbedding(torch.nn.Module):

    def __init__(self, d_model, max_len=128):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        for pos in range(max_len):
            # for each dimension of the each position
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))

        # include the batch size
        self.pe = pe.unsqueeze(0)
        # self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe

class BERTEmbedding(torch.nn.Module):
    """
    BERT Embedding which is consisted with under features
        1. TokenEmbedding : normal embedding matrix
        2. PositionalEmbedding : adding positional information using sin, cos
        2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2)
        sum of all these features are output of BERTEmbedding
    """

    def __init__(self, vocab_size, embed_size, seq_len=64, dropout=0.1):
        """
        :param vocab_size: total vocab size
        :param embed_size: embedding size of token embedding
        :param dropout: dropout rate
        """

        super().__init__()
        self.embed_size = embed_size
        # (m, seq_len) --> (m, seq_len, embed_size)
        # padding_idx is not updated during training, remains as fixed pad (0)
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0)
        self.position = PositionalEmbedding(d_model=embed_size, max_len=seq_len)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)


In [31]:
b = BERTEmbedding(s.num_item+3, 768, seq_len=64, dropout=0.1)

In [32]:
l = train_data[random.randrange(len(train_data))]

In [33]:
x = b.forward(l['bert_input'], l['segment_label'])

In [34]:
x.shape

torch.Size([1, 64, 768])

In [35]:
class MultiHeadAttention(torch.nn.Module):
  def __init__(self,heads,d_model,dropout = 0.1):
    super().__init__()
    assert d_model % heads == 0
    self.d_k = d_model // heads
    self.heads = heads
    self.dropout = torch.nn.Dropout(dropout)
    self.query = torch.nn.Linear(d_model, d_model)
    self.key = torch.nn.Linear(d_model, d_model)
    self.value = torch.nn.Linear(d_model, d_model)
    self.output_linear = torch.nn.Linear(d_model, d_model)

  def forward(self, query, key, value, mask = None):
    # (batch_size, max_len, d_model)
    query = self.query(query.float())
    key = self.key(key.float())
    value = self.value(value.float())
    # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
    query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
    key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
    value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
    # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
    scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))
    scores = scores.masked_fill(mask == 0, -1e9)
    # max_len X max_len matrix of attention
    weights = F.softmax(scores, dim=-1)
    weights = self.dropout(weights)
    #print(weights.dtype)

    # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
    context = torch.matmul(weights, value)

    # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
    context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

    # (batch_size, max_len, d_model)
    return self.output_linear(context)


In [36]:
l = train_data[random.randrange(len(train_data))]

In [37]:
t = MultiHeadAttention(4,768)

In [58]:
class FeedForward(torch.nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()

        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        d_model=768,
        heads=12,
        feed_forward_hidden=768 * 4,
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model,dropout)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

In [59]:
class BERT(torch.nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # paper noted they used 4 * hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=d_model)

        # multi-layers transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded token
        # (batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)
        return x

class NextSentencePrediction(torch.nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, 2)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token which is the [CLS]
        return self.softmax(self.linear(x[:, 0]))

class MaskedLanguageModel(torch.nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

class BERTLM(torch.nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.d_model)
        self.mask_lm = MaskedLanguageModel(self.bert.d_model, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)

In [60]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr


In [61]:
class BERTTrainer:
    def __init__(
        self,
        model,
        train_dataloader,
        test_dataloader=None,
        lr= 1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10,
        device='cuda'
        ):

        self.device = device
        self.model = model
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
            )

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, data in data_iter:

            # 0. batch_data will be sent into the device(GPU or cpu)
            data = {key: value.to(self.device) for key, value in data.items()}
            # Check if any token index is out of range
            if (data["bert_input"] >= self.model.bert.embedding.token.num_embeddings).any():
                print("Error: Token indices out of range encountered.")
                # Further debugging, e.g., print the offending indices
                invalid_indices = data["bert_input"][data["bert_input"] >= self.model.bert.embedding.token.num_embeddings]
                print("Invalid indices:", invalid_indices)
                break  # Stop the training process

            # 1. forward the next_sentence_prediction and masked_lm model
            next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])

            # 2-1. NLL(negative log likelihood) loss of is_next classification result
            next_loss = self.criterion(next_sent_output, data["is_next"])

            # 2-2. NLLLoss of predicting masked token word
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            # 3. backward and optimization only in train
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        print(
            f"EP{epoch}, {mode}: \
            avg_loss={avg_loss / len(data_iter)}, \
            total_acc={total_correct * 100.0 / total_element}"
        )

In [62]:
'''test run'''
import tqdm
train_data = BERTDataset(
   s,10, 64,config['mask_prob'])

train_loader = DataLoader(
   train_data, batch_size=32, shuffle=True, pin_memory=True)

bert_model = BERT(
  vocab_size=s.num_item+3,
  d_model=768,
  n_layers=2,
  heads=12,
  dropout=0.1
)

bert_lm = BERTLM(bert_model, s.num_item+3)
bert_trainer = BERTTrainer(bert_lm, train_loader, device='cpu')


Total Parameters: 29126913


In [63]:
epochs = 2

for epoch in range(epochs):
   bert_trainer.train(epoch)

EP_train:0:   0%|| 0/20 [00:00<?, ?it/s]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:   5%|| 1/20 [00:07<02:31,  8.00s/it]

{'epoch': 0, 'iter': 0, 'avg_loss': 9.793241500854492, 'avg_acc': 56.25, 'loss': 9.793241500854492}
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  10%|| 2/20 [00:13<02:00,  6.71s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  15%|| 3/20 [00:20<01:53,  6.68s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  20%|| 4/20 [00:25<01:37,  6.08s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  25%|| 5/20 [00:30<01:26,  5.77s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  30%|| 6/20 [00:37<01:22,  5.91s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  35%|| 7/20 [00:42<01:13,  5.66s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  40%|| 8/20 [00:48<01:10,  5.89s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  45%|| 9/20 [00:53<01:02,  5.65s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  50%|| 10/20 [00:59<00:58,  5.83s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  55%|| 11/20 [01:05<00:51,  5.70s/it]

{'epoch': 0, 'iter': 10, 'avg_loss': 9.78149075941606, 'avg_acc': 54.26136363636363, 'loss': 9.720915794372559}
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  60%|| 12/20 [01:10<00:44,  5.51s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  65%|| 13/20 [01:17<00:42,  6.06s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  70%|| 14/20 [01:22<00:34,  5.78s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  75%|| 15/20 [01:29<00:29,  5.98s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  80%|| 16/20 [01:34<00:22,  5.71s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  85%|| 17/20 [01:40<00:17,  5.80s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  90%|| 18/20 [01:45<00:11,  5.70s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:0:  95%|| 19/20 [01:50<00:05,  5.51s/it]

torch.Size([2, 12, 64, 64]) torch.Size([2, 12, 64, 64])
torch.Size([2, 12, 64, 64]) torch.Size([2, 12, 64, 64])


EP_train:0: 100%|| 20/20 [01:51<00:00,  5.57s/it]


EP0, train:             avg_loss=9.749251079559325,             total_acc=52.950819672131146


EP_train:1:   0%|| 0/20 [00:00<?, ?it/s]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:   5%|| 1/20 [00:06<02:00,  6.36s/it]

{'epoch': 1, 'iter': 0, 'avg_loss': 9.565668106079102, 'avg_acc': 34.375, 'loss': 9.565668106079102}
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  10%|| 2/20 [00:11<01:42,  5.70s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  15%|| 3/20 [00:18<01:42,  6.03s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  20%|| 4/20 [00:23<01:30,  5.67s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  25%|| 5/20 [00:29<01:26,  5.77s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  30%|| 6/20 [00:34<01:19,  5.71s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  35%|| 7/20 [00:39<01:11,  5.52s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  40%|| 8/20 [00:46<01:09,  5.81s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  45%|| 9/20 [00:51<01:01,  5.58s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  50%|| 10/20 [00:57<00:58,  5.88s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  55%|| 11/20 [01:02<00:50,  5.65s/it]

{'epoch': 1, 'iter': 10, 'avg_loss': 9.643505269830877, 'avg_acc': 47.44318181818182, 'loss': 9.597225189208984}
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  60%|| 12/20 [01:08<00:44,  5.59s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  65%|| 13/20 [01:14<00:40,  5.82s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  70%|| 14/20 [01:19<00:33,  5.63s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  75%|| 15/20 [01:27<00:31,  6.30s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  80%|| 16/20 [01:32<00:23,  5.95s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  85%|| 17/20 [01:39<00:18,  6.19s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  90%|| 18/20 [01:44<00:11,  5.90s/it]

torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])
torch.Size([32, 12, 64, 64]) torch.Size([32, 12, 64, 64])


EP_train:1:  95%|| 19/20 [01:50<00:05,  5.86s/it]

torch.Size([2, 12, 64, 64]) torch.Size([2, 12, 64, 64])
torch.Size([2, 12, 64, 64]) torch.Size([2, 12, 64, 64])


EP_train:1: 100%|| 20/20 [01:51<00:00,  5.57s/it]

EP1, train:             avg_loss=9.58076901435852,             total_acc=48.19672131147541



