In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import random
import numpy as np
import tqdm
import math

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def build_dictionary(dictionary_file_location):
        text_file = open(dictionary_file_location,"r")
        full_dictionary = text_file.read().splitlines()
        text_file.close()
        return full_dictionary
    
full_dictionary_location = "words_250000_train.txt"
full_dictionary = build_dictionary(full_dictionary_location)

In [6]:
class Dict_Dataset(Dataset):
    def __init__(self, dictionary):
        self.words = dictionary
        # Create a mapping from each letter to its corresponding index (0-25)
        self.alphabets = 'abcdefghijklmnopqrstuvwxyz'
        self.CHAR_TO_INDEX = {char: idx for idx, char in enumerate(self.alphabets)}

        self.letter_weight = {}
        for i in dictionary:
            for l in self.alphabets:
                if l in i:
                    if self.letter_weight.get(l):
                        self.letter_weight[l] += 1
                    else:
                        self.letter_weight[l] = 1

    def __len__(self):
        return len(self.words)
    
    def cnt_to_guesses(self,char_set,cnt):
        lst = list(char_set)
        return list(np.random.choice(lst, cnt, p=[self.letter_weight[i] for i in lst]/np.sum([self.letter_weight[i] for i in lst]), replace=False))
    
    def one_hot_encode(self,char):
        """Convert a character to a one-hot vector."""
        vec = torch.zeros(28)
        vec[self.CHAR_TO_INDEX[char]] = 1.0
        return vec

    def word_to_matrix(self,word, correct_guesses,wrong_guesses):
        word = word.lower()  # Ensure the word is lowercase
        matrix = torch.zeros(38, 28)  # Initialize a (38, 28) matrix with zeros 27,28 digits for depicting blanks and wrong vector respect.
        # matrix[:len(word),26]=1  ## depicting word length in our sequence
        # print(matrix)
        for i, char in enumerate(word):
            if char in correct_guesses:
                # print(char)
                matrix[i] = self.one_hot_encode(char)
        # print('rssa')
        for i, char in enumerate(wrong_guesses):
            # print(char)
            matrix[32+i] = self.one_hot_encode(char)  
            matrix[32+i,27]=1
        matrix[:len(word),26]=1  ## depicting word length in our sequence
        return matrix

    def multi_encode(self,set_char):
        string = ''.join(set_char)
        vec = torch.zeros(26)
        for char in string:
            vec[self.CHAR_TO_INDEX[char]] = 1.0
        return vec
    
    def __getitem__(self, idx):
        wrd = self.words[idx]
        set_alpha = set(wrd)

        if len(set_alpha)==1:
            return self.word_to_matrix( wrd, '', '' ), self.multi_encode(set_alpha)
        
        if 'e' in set_alpha:
            cnt_correct_guess = np.random.randint(len(set_alpha)-1) ## -1 for e, -1 for atleast one unguessed
            cnt_incorrect_guess = np.random.randint(6) ## 0 to 5
            correct_guesses = self.cnt_to_guesses(set_alpha-set('e'),cnt_correct_guess)
            wrong_guesses = self.cnt_to_guesses(set(self.alphabets)-set_alpha,cnt_incorrect_guess)
            return self.word_to_matrix( wrd, ''.join(correct_guesses)+'e', ''.join(wrong_guesses) ), self.multi_encode(set_alpha - set(''.join(correct_guesses)+'e'))
        
        elif 'a' in set_alpha:
            cnt_correct_guess = np.random.randint(len(set_alpha)-1) ## -1 for a, -1 for atleast one unguessed
            cnt_incorrect_guess = np.random.randint(5) ## 0 to 4 , one for 'e'
            correct_guesses = self.cnt_to_guesses(set_alpha-set('a'),cnt_correct_guess)
            wrong_guesses = self.cnt_to_guesses(set(self.alphabets)-set_alpha-set('e'),cnt_incorrect_guess)
            return self.word_to_matrix( wrd, ''.join(correct_guesses)+'a', ''.join(wrong_guesses)+'e' ), self.multi_encode(set_alpha - set(''.join(correct_guesses)+'a'))
        
        elif 'i' in set_alpha:
            cnt_correct_guess = np.random.randint(len(set_alpha)-1) ## -1 for i, -1 for atleast one unguessed
            cnt_incorrect_guess = np.random.randint(4) ## 0 to 3 , two for 'e','a'
            correct_guesses = self.cnt_to_guesses(set_alpha-set('i'),cnt_correct_guess)
            wrong_guesses = self.cnt_to_guesses(set(self.alphabets)-set_alpha-set('ea'),cnt_incorrect_guess)
            return self.word_to_matrix( wrd, ''.join(correct_guesses)+'i', ''.join(wrong_guesses)+'ea' ), self.multi_encode(set_alpha - set(''.join(correct_guesses)+'i'))
        
        elif 'o' in set_alpha:
            cnt_correct_guess = np.random.randint(len(set_alpha)-1) ## -1 for o, -1 for atleast one unguessed
            cnt_incorrect_guess = np.random.randint(3) ## 0 to 2 , three for 'e','a','i'
            correct_guesses = self.cnt_to_guesses(set_alpha-set('o'),cnt_correct_guess)
            wrong_guesses = self.cnt_to_guesses(set(self.alphabets)-set_alpha-set('eai'),cnt_incorrect_guess)
            return self.word_to_matrix( wrd, ''.join(correct_guesses)+'o', ''.join(wrong_guesses)+'eai' ), self.multi_encode(set_alpha - set(''.join(correct_guesses)+'o'))
        
        else:
            return self.word_to_matrix( wrd, '', 'eaio' ), self.multi_encode(set_alpha)


In [7]:
dataset = Dict_Dataset(full_dictionary)

In [8]:
# Creating the Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model # Dimensionality of the model
        self.seq_len = seq_len # Maximum sequence length
        self.dropout = nn.Dropout(dropout) # Dropout layer to prevent overfitting

        # Creating a positional encoding matrix of shape (seq_len, d_model) filled with zeros
        pe = torch.zeros(seq_len, d_model)

        # Creating a tensor representing positions (0 to seq_len - 1)
        position = torch.arange(0, seq_len, dtype = torch.float).unsqueeze(1) # Transforming 'position' into a 2D tensor['seq_len, 1']

        # Creating the division term for the positional encoding formula
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # Apply sine to even indices in pe
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices in pe
        pe[:, 1::2] = torch.cos(position * div_term)

        # Adding an extra dimension at the beginning of pe matrix for batch handling
        pe = pe.unsqueeze(0)

        # Registering 'pe' as buffer. Buffer is a tensor not considered as a model parameter
        self.register_buffer('pe', pe)

    def forward(self,x):
        # Addind positional encoding to the input tensor X
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x) # Dropout for regularization

In [9]:
class TransformerEncoderClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, embed_size=64, num_heads=4, hidden_dim=128, num_layers=2, dropout=0.1,seq_len=38):
        super(TransformerEncoderClassifier, self).__init__()
        
        # Linear layer to project the input to the embedding size
        self.embedding = nn.Linear(input_dim, embed_size)
        
        # Positional encoding
        self.positional_encoding = PositionalEncoding(embed_size, seq_len, dropout)
        
        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Fully connected layer for classification
        self.fc = nn.Linear(embed_size, num_classes)
    
    def forward(self, x):
        # Input shape: (batch_size, seq_len, input_dim)
        
        # Project input to embedding size
        x = self.embedding(x)
        
        # Add positional encoding
        x = self.positional_encoding(x)
        
        # Transpose for transformer: (seq_len, batch_size, embed_size)
        x = x.permute(1, 0, 2)
        
        # Apply the transformer encoder
        x = self.transformer_encoder(x)
        
        # Pooling over the sequence: mean of all positions
        x = x.mean(dim=0)
        
        # Fully connected layer for classification
        output = self.fc(x)
        # output = nn.Sigmoid()(output)
        
        return output


In [12]:
# Hyperparameters
input_dim = 28  # Input size of each sequence element
seq_len = 38    # Sequence length
num_classes = 26  # Number of classes
embed_size = 64  # Embedding size
num_heads = 4    # Number of heads in multi-head attention
hidden_dim = 128  # Hidden dimension size in the feedforward layer
num_layers = 4  # Number of Transformer Encoder layers
dropout = 0.1    # Dropout rate

# Create a dataset and data loader
dataset = Dict_Dataset(full_dictionary)
train_set, val_set = torch.utils.data.random_split(dataset, [0.8, 0.2])

train_dataloader = DataLoader(dataset, batch_size=96, shuffle=True,num_workers=8)
val_dataloader = DataLoader(val_set, batch_size=32, shuffle=True,num_workers=4)
device = 'cuda:2'
# Initialize the model, loss function, and optimizer
model = TransformerEncoderClassifier(input_dim=input_dim, num_classes=num_classes, embed_size=embed_size,
                                     num_heads=num_heads, hidden_dim=hidden_dim, num_layers=num_layers,
                                     dropout=dropout,seq_len=seq_len)
model.load_state_dict(torch.load('models/best_model_2_49'))

model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
max_epoch_loss = 99
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm.tqdm(train_dataloader):
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        # print(outputs)
        # stop
        loss = criterion(outputs.cpu(), labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataset)
    if epoch_loss < max_epoch_loss:
        max_epoch_loss = epoch_loss
        torch.save(model.state_dict(), f'models/best_model_2_{epoch}')
        torch.save(model.state_dict(), f'models/best_model_2__best')
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm.tqdm(val_dataloader):
            outputs = model(inputs.to(device)).cpu()
            # print(torch.argmax(outputs.data, 1))
            # print(labels)
            predicted = torch.argmax(outputs.data, 1)
            total += labels.size(0)
            # print(torch.sum(torch.Tensor([labels[ind,i] for ind,i in enumerate(predicted)])))
            correct += torch.sum(torch.Tensor([labels[ind,i] for ind,i in enumerate(predicted)])).item()

    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')


print("Training complete.")


100%|██████████| 2368/2368 [00:41<00:00, 56.90it/s]


Epoch [1/50], Loss: 0.2552


100%|██████████| 1421/1421 [00:07<00:00, 201.87it/s]


Accuracy: 65.40%


100%|██████████| 2368/2368 [00:41<00:00, 56.70it/s]


Epoch [2/50], Loss: 0.2550


100%|██████████| 1421/1421 [00:06<00:00, 206.60it/s]


Accuracy: 65.41%


100%|██████████| 2368/2368 [00:43<00:00, 54.27it/s]


Epoch [3/50], Loss: 0.2548


100%|██████████| 1421/1421 [00:07<00:00, 193.49it/s]


Accuracy: 65.75%


100%|██████████| 2368/2368 [00:44<00:00, 53.49it/s]


Epoch [4/50], Loss: 0.2548


100%|██████████| 1421/1421 [00:07<00:00, 184.64it/s]


Accuracy: 65.46%


100%|██████████| 2368/2368 [00:44<00:00, 53.17it/s]


Epoch [5/50], Loss: 0.2546


100%|██████████| 1421/1421 [00:07<00:00, 201.27it/s]


Accuracy: 65.99%


100%|██████████| 2368/2368 [00:43<00:00, 54.75it/s]


Epoch [6/50], Loss: 0.2546


100%|██████████| 1421/1421 [00:06<00:00, 214.98it/s]


Accuracy: 66.05%


100%|██████████| 2368/2368 [00:42<00:00, 56.18it/s]


Epoch [7/50], Loss: 0.2548


100%|██████████| 1421/1421 [00:06<00:00, 216.36it/s]


Accuracy: 65.70%


100%|██████████| 2368/2368 [00:36<00:00, 64.06it/s]


Epoch [8/50], Loss: 0.2545


100%|██████████| 1421/1421 [00:08<00:00, 161.31it/s]


Accuracy: 65.62%


100%|██████████| 2368/2368 [00:38<00:00, 62.13it/s]


Epoch [9/50], Loss: 0.2543


100%|██████████| 1421/1421 [00:06<00:00, 229.12it/s]


Accuracy: 65.81%


100%|██████████| 2368/2368 [00:37<00:00, 62.49it/s]


Epoch [10/50], Loss: 0.2542


100%|██████████| 1421/1421 [00:06<00:00, 215.55it/s]


Accuracy: 65.71%


100%|██████████| 2368/2368 [00:38<00:00, 61.21it/s]


Epoch [11/50], Loss: 0.2546


100%|██████████| 1421/1421 [00:06<00:00, 226.07it/s]


Accuracy: 65.65%


100%|██████████| 2368/2368 [00:37<00:00, 63.02it/s]


Epoch [12/50], Loss: 0.2541


100%|██████████| 1421/1421 [00:06<00:00, 216.43it/s]


Accuracy: 65.90%


100%|██████████| 2368/2368 [00:38<00:00, 62.13it/s]


Epoch [13/50], Loss: 0.2538


100%|██████████| 1421/1421 [00:06<00:00, 226.48it/s]


Accuracy: 65.88%


100%|██████████| 2368/2368 [00:38<00:00, 61.68it/s]


Epoch [14/50], Loss: 0.2539


100%|██████████| 1421/1421 [00:06<00:00, 219.87it/s]


Accuracy: 65.91%


100%|██████████| 2368/2368 [00:37<00:00, 63.79it/s]


Epoch [15/50], Loss: 0.2539


100%|██████████| 1421/1421 [00:05<00:00, 238.68it/s]


Accuracy: 65.93%


100%|██████████| 2368/2368 [00:37<00:00, 62.99it/s]


Epoch [16/50], Loss: 0.2537


100%|██████████| 1421/1421 [00:06<00:00, 208.36it/s]


Accuracy: 65.75%


100%|██████████| 2368/2368 [00:44<00:00, 53.53it/s]


Epoch [17/50], Loss: 0.2536


100%|██████████| 1421/1421 [00:06<00:00, 208.67it/s]


Accuracy: 66.19%


100%|██████████| 2368/2368 [00:40<00:00, 57.93it/s]


Epoch [18/50], Loss: 0.2532


100%|██████████| 1421/1421 [00:06<00:00, 223.61it/s]


Accuracy: 66.21%


100%|██████████| 2368/2368 [00:38<00:00, 62.26it/s]


Epoch [19/50], Loss: 0.2538


100%|██████████| 1421/1421 [00:05<00:00, 237.48it/s]


Accuracy: 66.31%


100%|██████████| 2368/2368 [00:37<00:00, 63.63it/s]


Epoch [20/50], Loss: 0.2534


100%|██████████| 1421/1421 [00:06<00:00, 221.41it/s]


Accuracy: 66.19%


100%|██████████| 2368/2368 [00:37<00:00, 63.56it/s]


Epoch [21/50], Loss: 0.2540


100%|██████████| 1421/1421 [00:05<00:00, 248.19it/s]


Accuracy: 66.10%


100%|██████████| 2368/2368 [00:38<00:00, 61.43it/s]


Epoch [22/50], Loss: 0.2535


100%|██████████| 1421/1421 [00:06<00:00, 214.13it/s]


Accuracy: 65.99%


100%|██████████| 2368/2368 [00:38<00:00, 61.78it/s]


Epoch [23/50], Loss: 0.2532


100%|██████████| 1421/1421 [00:07<00:00, 191.91it/s]


Accuracy: 66.05%


100%|██████████| 2368/2368 [00:43<00:00, 53.93it/s]


Epoch [24/50], Loss: 0.2535


100%|██████████| 1421/1421 [00:07<00:00, 191.44it/s]


Accuracy: 66.13%


100%|██████████| 2368/2368 [00:45<00:00, 52.33it/s]


Epoch [25/50], Loss: 0.2532


100%|██████████| 1421/1421 [00:06<00:00, 226.17it/s]


Accuracy: 66.52%


100%|██████████| 2368/2368 [00:38<00:00, 61.16it/s]


Epoch [26/50], Loss: 0.2533


100%|██████████| 1421/1421 [00:06<00:00, 234.17it/s]


Accuracy: 65.82%


100%|██████████| 2368/2368 [00:39<00:00, 60.37it/s]


Epoch [27/50], Loss: 0.2533


100%|██████████| 1421/1421 [00:06<00:00, 222.92it/s]


Accuracy: 66.15%


100%|██████████| 2368/2368 [00:38<00:00, 62.01it/s]


Epoch [28/50], Loss: 0.2529


100%|██████████| 1421/1421 [00:06<00:00, 216.75it/s]


Accuracy: 66.56%


100%|██████████| 2368/2368 [00:38<00:00, 62.04it/s]


Epoch [29/50], Loss: 0.2534


100%|██████████| 1421/1421 [00:06<00:00, 231.97it/s]


Accuracy: 65.84%


100%|██████████| 2368/2368 [00:38<00:00, 61.31it/s]


Epoch [30/50], Loss: 0.2530


100%|██████████| 1421/1421 [00:06<00:00, 209.49it/s]


Accuracy: 66.17%


100%|██████████| 2368/2368 [00:38<00:00, 61.93it/s]


Epoch [31/50], Loss: 0.2530


100%|██████████| 1421/1421 [00:06<00:00, 221.14it/s]


Accuracy: 65.96%


100%|██████████| 2368/2368 [00:38<00:00, 61.65it/s]


Epoch [32/50], Loss: 0.2524


100%|██████████| 1421/1421 [00:06<00:00, 224.87it/s]


Accuracy: 66.15%


100%|██████████| 2368/2368 [00:37<00:00, 63.65it/s]


Epoch [33/50], Loss: 0.2530


100%|██████████| 1421/1421 [00:06<00:00, 204.18it/s]


Accuracy: 67.02%


100%|██████████| 2368/2368 [00:38<00:00, 61.35it/s]


Epoch [34/50], Loss: 0.2532


100%|██████████| 1421/1421 [00:06<00:00, 218.32it/s]


Accuracy: 66.33%


100%|██████████| 2368/2368 [00:38<00:00, 62.22it/s]


Epoch [35/50], Loss: 0.2527


100%|██████████| 1421/1421 [00:06<00:00, 236.28it/s]


Accuracy: 66.77%


100%|██████████| 2368/2368 [00:37<00:00, 63.78it/s]


Epoch [36/50], Loss: 0.2528


100%|██████████| 1421/1421 [00:06<00:00, 213.54it/s]


Accuracy: 65.95%


100%|██████████| 2368/2368 [00:37<00:00, 63.75it/s]


Epoch [37/50], Loss: 0.2527


100%|██████████| 1421/1421 [00:06<00:00, 210.81it/s]


Accuracy: 66.33%


100%|██████████| 2368/2368 [00:39<00:00, 59.70it/s]


Epoch [38/50], Loss: 0.2525


100%|██████████| 1421/1421 [00:06<00:00, 229.36it/s]


Accuracy: 66.27%


100%|██████████| 2368/2368 [00:37<00:00, 62.55it/s]


Epoch [39/50], Loss: 0.2522


100%|██████████| 1421/1421 [00:06<00:00, 222.38it/s]


Accuracy: 66.04%


100%|██████████| 2368/2368 [00:37<00:00, 63.87it/s]


Epoch [40/50], Loss: 0.2526


100%|██████████| 1421/1421 [00:06<00:00, 221.94it/s]


Accuracy: 66.11%


100%|██████████| 2368/2368 [00:38<00:00, 61.95it/s]


Epoch [41/50], Loss: 0.2528


100%|██████████| 1421/1421 [00:06<00:00, 214.66it/s]


Accuracy: 66.17%


100%|██████████| 2368/2368 [00:37<00:00, 62.62it/s]


Epoch [42/50], Loss: 0.2527


100%|██████████| 1421/1421 [00:06<00:00, 227.54it/s]


Accuracy: 66.44%


100%|██████████| 2368/2368 [00:37<00:00, 62.82it/s]


Epoch [43/50], Loss: 0.2526


100%|██████████| 1421/1421 [00:06<00:00, 209.24it/s]


Accuracy: 66.20%


100%|██████████| 2368/2368 [00:38<00:00, 61.29it/s]


Epoch [44/50], Loss: 0.2528


100%|██████████| 1421/1421 [00:06<00:00, 229.32it/s]


Accuracy: 66.69%


100%|██████████| 2368/2368 [00:38<00:00, 61.75it/s]


Epoch [45/50], Loss: 0.2521


100%|██████████| 1421/1421 [00:06<00:00, 233.10it/s]


Accuracy: 65.97%


100%|██████████| 2368/2368 [00:36<00:00, 64.46it/s]


Epoch [46/50], Loss: 0.2524


100%|██████████| 1421/1421 [00:05<00:00, 244.56it/s]


Accuracy: 66.46%


100%|██████████| 2368/2368 [00:38<00:00, 61.66it/s]


Epoch [47/50], Loss: 0.2523


100%|██████████| 1421/1421 [00:06<00:00, 217.39it/s]


Accuracy: 66.51%


100%|██████████| 2368/2368 [00:39<00:00, 60.21it/s]


Epoch [48/50], Loss: 0.2524


100%|██████████| 1421/1421 [00:05<00:00, 249.35it/s]


Accuracy: 66.14%


100%|██████████| 2368/2368 [00:37<00:00, 63.29it/s]


Epoch [49/50], Loss: 0.2526


100%|██████████| 1421/1421 [00:05<00:00, 245.07it/s]


Accuracy: 66.28%


100%|██████████| 2368/2368 [00:37<00:00, 63.78it/s]


Epoch [50/50], Loss: 0.2521


100%|██████████| 1421/1421 [00:06<00:00, 206.32it/s]

Accuracy: 66.32%
Training complete.





In [15]:
model2 = TransformerEncoderClassifier(input_dim=input_dim, num_classes=num_classes, embed_size=embed_size,
                                     num_heads=num_heads, hidden_dim=hidden_dim, num_layers=num_layers,
                                     dropout=dropout,seq_len=seq_len)

In [17]:
model2.load_state_dict(torch.load('models/best_model_2_0'))
model2.eval()

TransformerEncoderClassifier(
  (embedding): Linear(in_features=28, out_features=64, bias=True)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQu

In [19]:
model2(torch.zeros(38,28)).data,1

(tensor([[-4.8633, -2.9174, -3.1582, -3.2095, -5.3365, -3.3531, -3.5696, -3.3052,
          -4.1689, -2.8368, -3.6926, -4.0942, -2.6316, -4.2230, -4.2359, -3.0520,
          -3.7881, -4.0173, -3.4806, -3.8285, -4.1650, -2.6947, -3.2260, -3.9559,
          -3.9787, -3.5007]]),
 1)

In [23]:
torch.argsort(model2(torch.zeros(38,28)).data,1,descending=True)[0]

tensor([12, 21,  9,  1, 15,  2,  3, 22,  7,  5, 18, 25,  6, 10, 16, 19, 23, 24,
        17, 11, 20,  8, 13, 14,  0,  4])

In [13]:
torch.argmax(model2(torch.zeros(38,28)).data,1).item()

NameError: name 'model2' is not defined

In [None]:
model2(torch.zeros(38,28))

tensor([[ -3.3406,  -3.7175,  -3.9934,  -4.2455, -14.9520,  -3.7689,  -3.7389,
          -3.6935,  -3.1755,  -4.1289,  -3.5122,  -4.3666,  -3.5186,  -4.3590,
          -3.5265,  -3.9609,  -7.0674,  -3.9472,  -3.7219,  -4.3358,  -3.0697,
          -3.9584,  -3.5460,  -4.6119,  -2.9986,  -4.9754]],
       grad_fn=<AddmmBackward0>)