In [46]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
import os

In [47]:
gpu_avail = torch.cuda.is_available()
print(f"Is the GPU available? {gpu_avail}")

Is the GPU available? True


In [48]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)

Device cuda


In [49]:
class CustomDataset(Dataset):
    def __init__(self, path):
        with open(f"{path}.sources") as f:
            self.sources = f.readlines()
            self.vocab_src = set()
            for lines in self.sources:
                self.vocab_src.update(lines[1:-1])
                
        with open(f"{path}.targets") as f:
            self.targets = f.readlines()
            self.vocab_tgt = set()
            for lines in self.targets:
                self.vocab_tgt.update(lines)
            

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, idx):
        return self.sources[idx][1:-1], self.targets[idx]
        

In [50]:
vocab_src = set()
vocab_tgt = set()
train_dataset = CustomDataset('Data/A3 files/train')
eval_dataset = CustomDataset('Data/A3 files/dev')
test_dataset = CustomDataset('Data/A3 files/test')
vocab_src.update(train_dataset.vocab_src)
vocab_src.update(eval_dataset.vocab_src)
vocab_src.update(test_dataset.vocab_src)
vocab_tgt.update(train_dataset.vocab_tgt)
vocab_tgt.update(eval_dataset.vocab_tgt)
vocab_tgt.update(test_dataset.vocab_tgt)
vocab_src = list(vocab_src)
vocab_tgt = list(vocab_tgt)
temp = {}
for ind, key in enumerate(vocab_src):
    temp[key] = ind
vocab_src = temp
temp = {}
for ind, key in enumerate(vocab_tgt):
    temp[key] = ind
vocab_tgt = temp
vocab_src["END"] = 84
vocab_src["PAD"] = 85
vocab_tgt["STR"] = 44
vocab_tgt["END"] = 45
vocab_tgt["PAD"] = 46
reverse_vocab_tgt = {}
for key in vocab_tgt:
    val = vocab_tgt[key]
    reverse_vocab_tgt[val] = key

In [51]:
train_loader = DataLoader(train_dataset, batch_size = 32, shuffle = True)
eval_loader = DataLoader(eval_dataset, batch_size = 32, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = 32, shuffle = True)

In [52]:
class Encoder(nn.Module):
    def __init__(self, dims = 512, hidden_size = 512,num_layers = 2, max_src = 500, max_tgt = 500):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(len(vocab_src),dims)
        self.input_size = dims
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.max_src = max_src
        self.max_tgt = max_tgt
        self.drop = nn.Dropout(p = 0.5)
        self.encoder = nn.LSTM(
            dims, hidden_size, num_layers, batch_first = True,bidirectional=True, dropout = 0.5
        )
        
    def encode_inp(self, x):
        encoded_x = torch.zeros(len(x), self.max_src, dtype = int) + 85
        for i in range(len(x)):
            for j in range(len(x[i])):
                encoded_x[i][j] = vocab_src[x[i][j]]
            encoded_x[i][len(x[i])] = vocab_src["END"]
        return encoded_x.to(device)
    
    def forward(self, x):
        encoded_x = self.encode_inp(x)
        input_seq = self.drop(self.embedding(encoded_x))
        hidden = torch.zeros(2*self.num_layers,input_seq.shape[0],self.hidden_size).to(device)
        cell = torch.zeros(2*self.num_layers,input_seq.shape[0],self.hidden_size).to(device)
        out, _ = self.encoder(input_seq,(hidden, cell))
        return out

    def inference(self, x):
        self.max_src = len(x)
        

In [53]:
def calculate_diversity_penalty(new_sequence, existing_sequences):
    """
    Calculate a diversity penalty based on the new sequence and existing sequences.
    This is a simple example of diversity penalty calculation and can be customized.

    Args:
    - new_sequence (torch.Tensor): The new sequence to be penalized.
    - existing_sequences (list of torch.Tensor): A list of existing sequences.

    Returns:
    - float: The diversity penalty score.
    """
    penalty = 0.0
    for seq in existing_sequences:
        similarity = torch.sum(torch.eq(new_sequence, seq[0]).float()) / len(new_sequence)
        penalty += similarity
    return penalty

def beam_search_decoder(probabilities, beam_width, max_length, diversity_penalty_weight=5):
    """
    Beam search decoder for sequence generation.

    Args:
    - probabilities (torch.Tensor): A 2D tensor of shape (sequence_length, vocab_size)
      containing the predicted probabilities for each token at each time step.
    - beam_width (int): The number of sequences to consider at each decoding step.
    - max_length (int): The maximum length of the generated sequence.

    Returns:
    - List of tuples, each containing (sequence, score), where:
      - sequence (list): A list of token IDs representing the generated sequence.
      - score (float): The log-likelihood score of the sequence.
    """
    out  = torch.argmax(nn.Softmax(dim = 1)(probabilities), dim = 1)
    seq_len = 0
    for char in out:
        if(char == 45):
            break
        else:
            seq_len += 1

    # Get the sequence length and vocabulary size
    sequence_length, vocab_size = probabilities.shape
    sequence_length = seq_len
    max_length = seq_len

    # Initialize the beam with the empty sequence
    beam = [(torch.tensor([], dtype=torch.long).to(device), 0.0)]

    # Iterate through each time step
    for t in range(max_length):
        new_beam = []

        # Expand the beam by considering the top 'beam_width' candidates at each step
        for sequence, score in beam:
            # If the sequence is already at the maximum length, keep it as is
            if len(sequence) == max_length:
                new_beam.append((sequence, score))
                continue

            # Get the probabilities for the next token
            t_probs = probabilities[t]

            # Get the top 'beam_width' token IDs and their corresponding log-likelihood scores
            top_scores, top_tokens = torch.topk(t_probs, beam_width)

            # Expand the current sequence with each of the top tokens
            for token, token_score in zip(top_tokens, top_scores):
                new_sequence = torch.cat([sequence, token.unsqueeze(0)], dim=0)
                new_score = score + token_score.item()
    
                # Apply the diversity penalty
                if len(new_sequence) > 1:
                    # Calculate a penalty based on sequence diversity
                    diversity_penalty = diversity_penalty_weight * calculate_diversity_penalty(new_sequence, new_beam)
                    new_score -= diversity_penalty
                    
                new_beam.append((new_sequence, new_score))
        print(t)

        # Keep the top 'beam_width' candidates
        new_beam.sort(key=lambda x: -x[1])
        beam = new_beam[:beam_width]

    # Return the top sequence and its score
    return [(sequence.tolist(), score) for sequence, score in beam]

In [54]:
class Decoder(nn.Module):
    def __init__(self,dims = 512, hidden_size = 512,num_layers = 2, max_src = 500, max_tgt = 500):
        super(Decoder, self).__init__()
#         self.embedding = embedding
        self.embedding = nn.Embedding(len(vocab_tgt),dims)
        self.input_size = dims
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.max_src = max_src
        self.max_tgt = max_tgt
        self.drop = nn.Dropout(p = 0.5)
        self.dec_cells = nn.ModuleList([nn.LSTMCell(2*hidden_size+dims, hidden_size), nn.LSTMCell(hidden_size, hidden_size)])
        self.linear = nn.Linear(hidden_size,len(vocab_tgt)-1)
    
    def encode_inp(self, x):
        encoded_x = torch.zeros(len(x), self.max_tgt+1, dtype = int) + 46
        encoded_x_end = torch.zeros(len(x))
        for i in range(len(x)):
            encoded_x[i][0] = vocab_tgt["STR"]
            for j in range(len(x[i])):
                encoded_x[i][j+1] = vocab_tgt[x[i][j]]
            encoded_x[i][len(x[i])+1] = vocab_tgt["END"]
            encoded_x_end[i] = len(x[i])+1
        return encoded_x.to(device), encoded_x_end
    
    def calcontext(self, timestep, query):
        extended_query = torch.cat((query, query), dim = 1)
        permuted_context = self.context.permute(1,0,2)
#         for encoder_timestep in range(self.context.shape[1]):
#             scores.append(torch.sum(self.context[:,encoder_timestep] * extended_query, dim = 1, keepdims=True))
#         scores = torch.cat(scores, dim = 1)
        scores = torch.sum(permuted_context * extended_query, dim = 2).permute(1,0)
        weights = nn.Softmax(dim = 1)(scores).unsqueeze(2)
        alignment = torch.sum(weights * self.context,  dim = 1)
        
        return alignment
        
    
    def forward(self, context, target_,teacher_ratio):
        self.context = context
        encoded_x, encoded_x_end = self.encode_inp(target_)
        target_seq = self.embedding(encoded_x)
        
        initial_hidden1 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        initial_cell1 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        initial_hidden2 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        initial_cell2 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        
        outputs = []
        hidden_states = []
        cell_states = []
        query = [initial_hidden2]
        for timestep in range(self.max_tgt+1):
            if(timestep == 0):
                (h_t1, c_t1) = self.dec_cells[0](self.drop(torch.cat((target_seq[:,timestep],self.calcontext(0,query[-1])),dim=1)), (initial_hidden1, initial_cell1))
                (h_t2, c_t2) = self.dec_cells[1](self.drop(h_t1), (initial_hidden2, initial_cell2))
            else:
                input = []
                if(torch.rand(1).item() < teacher_ratio):
                    input = target_seq[:,timestep]
                else:
                    input = self.embedding(torch.argmax(nn.Softmax(dim = 1)(outputs[-1][:,0]),dim=1))
                    
                (h_t1, c_t1) = self.dec_cells[0](self.drop(torch.cat((input,self.calcontext(timestep,query[-1])),dim=1)), (hidden_states[-1][0], cell_states[-1][0]))
                (h_t2, c_t2) = self.dec_cells[1](self.drop(h_t1), (hidden_states[-1][1], cell_states[-1][1]))
            hidden_states.append([h_t1, h_t2])
            cell_states.append([c_t1, c_t2])
            query.append(h_t2)
            out = self.linear(h_t2)
            outputs.append(out.unsqueeze(1))
    

        output_prob = torch.cat(outputs,dim = 1)
        
        return nn.LogSoftmax(dim = 2)(output_prob), encoded_x
    
    def seq_to_vis(self, seq):
        vis = ""
        for char in seq:
            char = char
            if(char == 45):
                return vis
            vis += reverse_vocab_tgt[char]
        
        return vis


    def inference(self, context):
        self.context = context
        start = torch.zeros(1,1, dtype = int).to(device) + 44
        target_seq = self.embedding(start)
        
        initial_hidden1 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        initial_cell1 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        initial_hidden2 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        initial_cell2 = torch.rand(target_seq.shape[0], self.hidden_size).to(device)
        
        outputs = []
        hidden_states = []
        cell_states = []
        query = [initial_hidden2]
        for timestep in range(self.max_tgt+1):
            if(timestep == 0):
                (h_t1, c_t1) = self.dec_cells[0](self.drop(torch.cat((target_seq[:,timestep],self.calcontext(0,query[-1])),dim=1)), (initial_hidden1, initial_cell1))
                (h_t2, c_t2) = self.dec_cells[1](self.drop(h_t1), (initial_hidden2, initial_cell2))
            else:
                input = self.embedding(torch.argmax(nn.Softmax(dim = 1)(outputs[-1][:,0]),dim=1))
                    
                (h_t1, c_t1) = self.dec_cells[0](self.drop(torch.cat((input,self.calcontext(timestep,query[-1])),dim=1)), (hidden_states[-1][0], cell_states[-1][0]))
                (h_t2, c_t2) = self.dec_cells[1](self.drop(h_t1), (hidden_states[-1][1], cell_states[-1][1]))
            hidden_states.append([h_t1, h_t2])
            cell_states.append([c_t1, c_t2])
            query.append(h_t2)
            out = self.linear(h_t2)
            outputs.append(out.unsqueeze(1))
    
        output_prob = torch.cat(outputs,dim = 1)
        out = torch.argmax(nn.Softmax(dim = 2)(output_prob), dim = 2)
        
#         Example usage:
#         Replace 'probabilities', 'beam_width', and 'max_length' with your actual values
#         probabilities = torch.tensor(...)  # Shape: (sequence_length, vocab_size)
#         beam_width = 3
#         max_length = 10
        decoded_sequences = beam_search_decoder(output_prob[0], beam_width = 15, max_length = 500)
        vis = []
        for sequence, score in decoded_sequences:
            vis.append(self.seq_to_vis(sequence))
            # print(f"Sequence: {sequence}, {self.seq_to_vis(sequence)}, Log-Likelihood Score: {score}")

        return vis

    


In [55]:
encoder = Encoder()
decoder = Decoder()
encoder.to(device)
decoder.to(device)
optimizer_encoder = optim.Adam(encoder.parameters(), lr=0.0001)
optimizer_decoder = optim.Adam(decoder.parameters(), lr=0.0001)
criterion = nn.NLLLoss(ignore_index = 46)

# Set up early stopping parameters
patience = 2000  # Number of epochs to wait for improvement
best_val_loss = float('inf')
epochs_since_improvement = 0
loss_val = 0.0

In [56]:
for epoch in range(4):
  batch = 1
  for data, vis in train_loader:
    # Training Loop
      encoder.train()
      decoder.train()
      optimizer_encoder.zero_grad()
      optimizer_decoder.zero_grad()
      
      context = encoder(data)
      output, encoded_tgt = decoder(context, vis, 1.0)
      output = output.permute(0,2,1)
      loss = criterion(output[:,:,:-1], encoded_tgt[:,1:])
      
      loss.backward()
      optimizer_decoder.step()
      optimizer_encoder.step()
      
      print(f"Epoch[{epoch}], Batch[{batch}], Train loss: {loss.item()}")
        
    # Validation (with one minibatch only)
      with torch.no_grad():
        encoder.eval()
        decoder.eval()
        for data, vis in eval_loader:
          context_ = encoder(data)
          output_, encoded_tgt_ = decoder(context_,vis,1.0)
          output_ = output_.permute(0,2,1)
          loss_val = criterion(output_[:,:,:-1], encoded_tgt_[:,1:])
          print(f"Epoch[{epoch}], Val loss: {loss_val.item()}")
        
          checkpoint_encoder = {
                'model_state_dict': encoder.state_dict(),
                'optimizer_state_dict': optimizer_encoder.state_dict(),
          }
          checkpoint_decoder = {
                'model_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer_decoder.state_dict(),
          }
          
          if loss_val < best_val_loss:
              best_val_loss = loss_val
              epochs_since_improvement = 0
              torch.save(checkpoint_encoder, 'Model/encoder.pth')
              torch.save(checkpoint_decoder, 'Model/decoder.pth')
          else:
              epochs_since_improvement += 1
          break

      # Check if we should stop training early
      if epochs_since_improvement >= patience:
        print(f"Early stopping after {epoch} epochs , {batch} batches with no improvement.")
        break
      batch += 1
      
  if loss_val < best_val_loss:
      torch.save(checkpoint_encoder, 'Model/encoder.pth')
      torch.save(checkpoint_decoder, 'Model/decoder.pth')
        
      

Epoch[0], Batch[1], Train loss: 3.836839199066162
Epoch[0], Val loss: 3.8171885013580322
Epoch[0], Batch[2], Train loss: 3.8227381706237793
Epoch[0], Val loss: 3.795388698577881
Epoch[0], Batch[3], Train loss: 3.8073503971099854
Epoch[0], Val loss: 3.775780439376831
Epoch[0], Batch[4], Train loss: 3.791015863418579
Epoch[0], Val loss: 3.7505481243133545
Epoch[0], Batch[5], Train loss: 3.772714853286743
Epoch[0], Val loss: 3.7165732383728027
Epoch[0], Batch[6], Train loss: 3.752525806427002
Epoch[0], Val loss: 3.679064989089966
Epoch[0], Batch[7], Train loss: 3.7234628200531006
Epoch[0], Val loss: 3.6283559799194336
Epoch[0], Batch[8], Train loss: 3.6851515769958496
Epoch[0], Val loss: 3.578216075897217
Epoch[0], Batch[9], Train loss: 3.643625259399414
Epoch[0], Val loss: 3.524170398712158
Epoch[0], Batch[10], Train loss: 3.59562087059021
Epoch[0], Val loss: 3.458437204360962
Epoch[0], Batch[11], Train loss: 3.538316011428833
Epoch[0], Val loss: 3.388319730758667
Epoch[0], Batch[12], Tr

KeyboardInterrupt: 

In [57]:
encoder_ = Encoder().to(device)
decoder_ = Decoder().to(device)
checkpoint_enc = torch.load('Model/encoder.pth',map_location=torch.device('cpu'))
checkpoint_dec = torch.load('Model/decoder.pth',map_location=torch.device('cpu'))
encoder_.load_state_dict(checkpoint_enc['model_state_dict'])
decoder_.load_state_dict(checkpoint_dec['model_state_dict'])
encoder_.eval()
decoder_.eval()
# optimizer_encoder = optim.Adam(encoder_.parameters(), lr=0.0001)
# optimizer_decoder = optim.Adam(decoder_.parameters(), lr=0.0001)
optimizer_encoder.load_state_dict(checkpoint_enc['optimizer_state_dict'])
optimizer_decoder.load_state_dict(checkpoint_dec['optimizer_state_dict'])

criterion = nn.NLLLoss(ignore_index = 46)

# with torch.no_grad():
total_test = 0.0
batch = 1
for data, vis in train_loader:
  context_ = encoder_(data)
  output_, encoded_tgt_ = decoder_(context_,vis,1.0)
  output_ = output_.permute(0,2,1)
  loss_test = criterion(output_[:,:,:-1], encoded_tgt_[:,1:])
  total_test += loss_test.item()
  print(f"Batch [{batch}], loss: {loss_test.item()}")
  batch += 1
print(f"Average log perplexity on the test set: {total_test / len(test_loader)}")

Batch [1], loss: 0.03648332133889198
Batch [2], loss: 0.03610321879386902
Batch [3], loss: 0.037116486579179764
Batch [4], loss: 0.03736161068081856
Batch [5], loss: 0.03656899183988571
Batch [6], loss: 0.038916781544685364
Batch [7], loss: 0.03733036667108536
Batch [8], loss: 0.0360385924577713
Batch [9], loss: 0.03690795600414276
Batch [10], loss: 0.03645802661776543
Batch [11], loss: 0.035165682435035706
Batch [12], loss: 0.03644541651010513
Batch [13], loss: 0.03623410314321518
Batch [14], loss: 0.03796388581395149
Batch [15], loss: 0.03917788714170456
Batch [16], loss: 0.033612918108701706
Batch [17], loss: 0.03582746163010597
Batch [18], loss: 0.03809986263513565
Batch [19], loss: 0.03623787313699722
Batch [20], loss: 0.03773220628499985
Batch [21], loss: 0.03618108853697777
Batch [22], loss: 0.0351506844162941
Batch [23], loss: 0.03792411461472511
Batch [24], loss: 0.03622032701969147
Batch [25], loss: 0.03626759722828865
Batch [26], loss: 0.03695521131157875
Batch [27], loss: 0

KeyboardInterrupt: 

In [58]:
def inference(model, context):
    model.context = context
    start = torch.zeros(1,1, dtype = int).to(device) + 44
    target_seq = model.embedding(start)
    
    initial_hidden1 = torch.rand(target_seq.shape[0], model.hidden_size).to(device)
    initial_cell1 = torch.rand(target_seq.shape[0], model.hidden_size).to(device)
    initial_hidden2 = torch.rand(target_seq.shape[0], model.hidden_size).to(device)
    initial_cell2 = torch.rand(target_seq.shape[0], model.hidden_size).to(device)
    
    outputs = []
    hidden_states = []
    cell_states = []
    query = [initial_hidden2]
    for timestep in range(model.max_tgt+1):
        if(timestep == 0):
            (h_t1, c_t1) = model.dec_cells[0](model.drop(torch.cat((target_seq[:,timestep],model.calcontext(0,query[-1])),dim=1)), (initial_hidden1, initial_cell1))
            (h_t2, c_t2) = model.dec_cells[1](model.drop(h_t1), (initial_hidden2, initial_cell2))
        else:
            input = model.embedding(torch.argmax(nn.Softmax(dim = 1)(outputs[-1][:,0]),dim=1))
                
            (h_t1, c_t1) = model.dec_cells[0](model.drop(torch.cat((input,model.calcontext(timestep,query[-1])),dim=1)), (hidden_states[-1][0], cell_states[-1][0]))
            (h_t2, c_t2) = model.dec_cells[1](model.drop(h_t1), (hidden_states[-1][1], cell_states[-1][1]))
        hidden_states.append([h_t1, h_t2])
        cell_states.append([c_t1, c_t2])
        query.append(h_t2)
        out = model.linear(h_t2)
        outputs.append(out.unsqueeze(1))

    output_prob = torch.cat(outputs,dim = 1)
    out = torch.argmax(nn.Softmax(dim = 2)(output_prob), dim = 2)
    
#         Example usage:
#         Replace 'probabilities', 'beam_width', and 'max_length' with your actual values
#         probabilities = torch.tensor(...)  # Shape: (sequence_length, vocab_size)
#         beam_width = 3
#         max_length = 10
    decoded_sequences = beam_search_decoder(output_prob[0], beam_width = 15, max_length = 500, diversity_penalty_weight=0.7)
    vis = []
    for sequence, score in decoded_sequences:
        vis.append(model.seq_to_vis(sequence))
        # print(f"Sequence: {sequence}, {self.seq_to_vis(sequence)}, Log-Likelihood Score: {score}")

    return vis

In [59]:
with open('Data/A3 files/progression.txt') as f:
    progression, progression_transformed = f.readlines()
    progression_transformed = [progression_transformed]

with torch.no_grad():
    encoder_.eval()
    decoder_.eval()
    context_ = encoder_(progression_transformed)
    output = inference(decoder_,context_)
    for vis in output:
        print(vis)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "detail": {"field": "num0", "type": "ordinal"}, "y": {"field": "*", "aggregate": "count", "type": "quantitative"}}, "mark": "bar"}

{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "det

In [60]:
for vis in output:
    print(vis)

{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "detail": {"field": "num0", "type": "ordinal"}, "y": {"field": "*", "aggregate": "count", "type": "quantitative"}}, "mark": "bar"}

{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "detail": {"field": "num0", "type": "ordinal"}, "y": {"field": "*", "aggregate": "count", "type": "quantitative"}}, "mark": "bar"},
{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "detail": {"field": "num0", "type": "ordinal"}, "y": {"field": "*", "aggregate": "count", "type": "quantitative"}}, "mark": "bar"}
{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "detail": {"field": "num0", "type": "ordinal"}, "y": {"field": "*", "aggregate": "count", "type": "quantitative"}}, "mark": "bar"}}
{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "detail": {"field": "num0", "type": "ordinal"}, "y": {"field": "*", "aggregate": "count", "type": 

In [None]:
with open('Data/A3 files/progression.txt') as f:
    progression, progression_transformed = f.readlines()
    progression_transformed = [progression_transformed]

with torch.no_grad():
    encoder_.eval()
    decoder_.eval()
    context_ = encoder_(progression_transformed)
    output = inference(decoder_,context_)
    for vis in output:
        print(vis)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "y": {"field": "num0", "type": "ordinal"}, "detail": {"field": "*", "aggregate": "count", "type": "quantitative"}}, "mark": "bar"}

{"encoding": {"x": {"field": "num2", "type": "quantitative", "bin": true}, "y":