In [77]:
import torch
import pandas as pd

def load_arguments(filename: str):
    with open(filename, "r", encoding="utf8") as f:
        lines = f.readlines()
        headers = lines[0].strip().split("\t")
        records = [l.strip().split("\t") for l in lines[1:]]
    return headers, records


def load_arguments_as_df(filename: str):
    return pd.read_csv(filename, encoding="utf8", sep="\t").set_index("Argument ID")


def load_value_categories(filename: str):
    return pd.read_csv(filename, encoding="utf8", sep="\t").set_index("Argument ID")
# Encode labels to vector space

opposite_pairs = {
    "Self-direction: thought": (0, 1),
    "Self-direction: action": (1, 1),
    "Stimulation": (2, 1),
    "Hedonism": (3, 1),
    "Achievement": (4, 1),
    "Power: dominance": (5, 1),
    "Power: resources": (6, 1),
    "Face": (7, 1),
    "Security: personal": (8, 1),
    "Security: societal": (9, 1),
    "Tradition": (0, -1),
    "Conformity: rules": (1, -1),
    "Conformity: interpersonal": (2, -1),
    "Humility": (3, -1),
    "Benevolence: caring": (4, -1),
    "Benevolence: dependability": (5, -1),
    "Universalism: concern": (6, -1),
    "Universalism: nature": (7, -1),
    "Universalism: tolerance": (8, -1),
    "Universalism: objectivity": (9, -1),
}

pair_dict = {0: ['Self-direction: thought', 'Tradition'], 1: ['Self-direction: action', 'Conformity: rules'], 2: ['Stimulation', 'Conformity: interpersonal'], 3: ['Hedonism', 'Humility'], 4: ['Achievement', 'Benevolence: caring'],
             5: ['Power: dominance', 'Benevolence: dependability'], 6: ['Power: resources', 'Universalism: concern'], 7: ['Face', 'Universalism: nature'], 8: ['Security: personal', 'Universalism: tolerance'], 9: ['Security: societal', 'Universalism: objectivity']}


def encode_label(inputRow):
    output_vector = [0] * len(pair_dict)
    for k, v in pair_dict.items():
        output_vector[k] = inputRow[v[0]] + inputRow[v[1]] * -1
    return output_vector

def decode_label(embeddedVal, threshold=0.3):
    out = [0] * 20
    for i, v in pair_dict.items():
        if(embeddedVal[i] >= threshold):
            out[i] = 1
        elif(embeddedVal[i] <= (-threshold)):
            out[i + 10] = 1
    return out



In [86]:

# from helper import load_arguments, \
#     load_arguments_as_df,\
#     load_value_categories, encode_label

argument_file_path = "data/arguments-training.tsv"
headers, argument_list = load_arguments(argument_file_path)
argument_df = load_arguments_as_df(argument_file_path)

labels_file_path = "data/labels-training.tsv"
labels_df = load_value_categories(labels_file_path)

# validation
val_argument_file_path = "data/arguments-validation.tsv"
val_labels_file_path = "data/labels-validation.tsv"
val_headers, val_argument_list  = load_arguments(val_argument_file_path)
val_argument_df = load_arguments_as_df(val_argument_file_path)
val_labels_df = load_value_categories(val_labels_file_path)

# test
test_argument_file_path = "data/arguments-test.tsv"
# val_labels_file_path = "data/labels-validation.tsv"
test_headers, test_argument_list  = load_arguments(test_argument_file_path)
test_argument_df = load_arguments_as_df(test_argument_file_path)
# val_labels_df = load_value_categories(val_labels_file_path)

In [88]:
argument_df.shape

(5393, 3)

In [80]:
val_argument_df.shape

(1896, 3)

In [87]:
test_argument_df.shape

(1576, 3)

In [2]:
argument_df["Stance"].unique()

array(['in favor of', 'against', 'in favour of'], dtype=object)

In [85]:
argument_df.iloc[2]["Conclusion"]

'We should end the use of economic sanctions'

In [3]:
labels_df.head(5)

Unnamed: 0_level_0,Self-direction: thought,Self-direction: action,Stimulation,Hedonism,Achievement,Power: dominance,Power: resources,Face,Security: personal,Security: societal,Tradition,Conformity: rules,Conformity: interpersonal,Humility,Benevolence: caring,Benevolence: dependability,Universalism: concern,Universalism: nature,Universalism: tolerance,Universalism: objectivity
Argument ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
A01002,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
A01005,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0
A01006,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0
A01007,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,0
A01008,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0


In [4]:
# Clean arguments
argument_df["Stance_encoded"] = argument_df.apply(lambda row: 0 if row["Stance"] == "against" else 1, axis=1)

In [5]:
for s in labels_df.columns:
    print(f'"{s}"')

"Self-direction: thought"
"Self-direction: action"
"Stimulation"
"Hedonism"
"Achievement"
"Power: dominance"
"Power: resources"
"Face"
"Security: personal"
"Security: societal"
"Tradition"
"Conformity: rules"
"Conformity: interpersonal"
"Humility"
"Benevolence: caring"
"Benevolence: dependability"
"Universalism: concern"
"Universalism: nature"
"Universalism: tolerance"
"Universalism: objectivity"


In [6]:
torch.cat(tuple(labels_df.apply(lambda row: torch.Tensor(encode_label(row)).reshape(1,-1), axis=1)), dim=0)

tensor([[ 0.,  0.,  0.,  ...,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  ...,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  ...,  0.,  0.,  1.],
        ...,
        [ 0.,  0.,  0.,  ..., -1.,  0.,  0.],
        [ 0., -1.,  0.,  ...,  0.,  1.,  0.],
        [ 0.,  1.,  0.,  ...,  0.,  0.,  1.]])

In [7]:
print("Max length conclusion: {}, premise: {}".format(argument_df["Conclusion"].map(lambda x: len(x.split(" "))).max(), argument_df["Premise"].map(lambda x: len(x.split(" "))).max()))
# Conclusion length 64, Premise length 256

Max length conclusion: 35, premise: 133


In [8]:
from transformers import BertTokenizer, BertModel
import numpy as np

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

sentence_model = BertModel.from_pretrained("prajjwal1/bert-small")

def encode_text(text:str, max_length:int):
    encoded = tokenizer.encode_plus(
        text=text,
        add_special_tokens=True,
        truncation=True,
        max_length = max_length,
        padding="max_length",
        return_attention_mask = True,
        return_tensors = "pt"
    )
    output = sentence_model(encoded["input_ids"])
    _, pooled_output = output[:2]
#     return torch.cat((encoded["input_ids"], encoded["attention_mask"]), dim=1)
    return pooled_output[0]

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
encode_text(argument_df.iloc[0][0], 64).size()

torch.Size([512])

In [10]:
# Build features
conclusions = list()
with torch.no_grad():
    for i, v in argument_df.iterrows():
        conclusions.append(encode_text(v["Conclusion"], 64).reshape(1,-1))
conclusions = torch.cat(conclusions, dim=0)

arguments = list()
with torch.no_grad():
    for i, v in argument_df.iterrows():
        arguments.append(encode_text(v["Premise"], 256).reshape(1,-1))
arguments = torch.cat(arguments, dim=0)

stance = torch.Tensor(tuple(argument_df["Stance_encoded"])).reshape(-1,1)

In [11]:
stance = stance.reshape(-1,1)

In [12]:
conclusions.size()
arguments.size()
stance.size()

torch.Size([5393, 1])

In [29]:
x_train = torch.cat((conclusions, arguments, stance), dim=1)
y_train = torch.cat(tuple(labels_df.apply(lambda row: torch.Tensor(encode_label(row)).reshape(1,-1), axis=1)), dim=0)

In [30]:
x_train.size()

torch.Size([5393, 1025])

In [33]:
torch.min(x_train, dim=1).values.size()

torch.Size([5393])

In [40]:
torch.max(x_train.abs(), dim=1, keepdim=True)[0].size()

torch.Size([5393, 1])

In [39]:
(x_train/torch.max(x_train.abs(), dim=1, keepdim=True)[0]).size()

torch.Size([5393, 1025])

In [None]:
x_train.size()

In [None]:
y_train.size()

https://stackoverflow.com/questions/68011633/train-multi-output-regression-model-in-pytorch

In [90]:
import torch

class HumanValueRegressor(torch.nn.Module):
    def __init__(self):
        super(HumanValueRegressor, self).__init__()
        self.hidden1 = torch.nn.Linear(1025, 2000)
        self.hidden2 = torch.nn.Linear(2000, 1000)
        self.hidden3 = torch.nn.Linear(1000, 800)
        self.hidden4 = torch.nn.Linear(800, 500)
        self.output = torch.nn.Linear(500, 10)
        
        torch.nn.init.xavier_uniform_(self.hidden1.weight)
        torch.nn.init.zeros_(self.hidden1.bias)
        torch.nn.init.xavier_uniform_(self.hidden2.weight)
        torch.nn.init.zeros_(self.hidden2.bias)
        torch.nn.init.xavier_uniform_(self.hidden3.weight)
        torch.nn.init.zeros_(self.hidden3.bias)
        torch.nn.init.xavier_uniform_(self.hidden4.weight)
        torch.nn.init.zeros_(self.hidden4.bias)
        torch.nn.init.xavier_uniform_(self.output.weight)
        torch.nn.init.zeros_(self.output.bias)
        
    def forward(self, x):
        x = torch.relu(self.hidden1(x))
        x = torch.relu(self.hidden2(x))
        x = torch.relu(self.hidden3(x))
        x = torch.relu(self.hidden4(x))
#         x = torch.tanh(self.output(x))
        x = self.output(x)
        return x/(torch.max(x.abs(), dim=1, keepdim=True)[0])
    
model = HumanValueRegressor()

In [42]:
from tqdm.notebook import tqdm
import random

def training_loop(
    in_features, 
    label_features,
    batch_size,
    epochs,
    model
):
    print("Training...")
    criterion = torch.nn.MSELoss()

    # Create batches
    batches = []
    for i in range(0, len(in_features), batch_size):
        batches.append(
        (
            in_features[i:i+batch_size],
            label_features[i:i+batch_size]
        ))
        
    random.shuffle(batches)

    for i in range(epochs):
        losses = []
        for features, labels in tqdm(batches):
            optimizer.zero_grad()
            out_pred = model(features)
            loss = criterion(out_pred, labels)
#             print(loss)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print(f"epoch {i}, loss: {sum(losses)/len(losses)}")
    return model

In [91]:
# Training
LR = 0.000001
optimizer = torch.optim.Adam(model.parameters(), LR)
epochs = 100
batch_size = 8

trained_model = training_loop(
    x_train, 
    y_train,
    batch_size,
    epochs,
    model
)

Training...


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 0, loss: 0.380630055224454


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 1, loss: 0.3689585128095415


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 2, loss: 0.3676105394186797


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 3, loss: 0.36601903595306257


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 4, loss: 0.36392957972155676


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 5, loss: 0.36156692774207505


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 6, loss: 0.35935617018629


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 7, loss: 0.3575242621148074


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 8, loss: 0.35592696401808


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 9, loss: 0.35453618212982463


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 10, loss: 0.3532910475465986


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 11, loss: 0.3521299375648852


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 12, loss: 0.3510498618196558


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 13, loss: 0.3499722197099968


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 14, loss: 0.3488732511467404


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 15, loss: 0.3479040118279281


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 16, loss: 0.34701821155018275


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 17, loss: 0.34615129718074095


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 18, loss: 0.3453215373886956


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 19, loss: 0.34456672491850676


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 20, loss: 0.343805973706422


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 21, loss: 0.3430773952934477


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 22, loss: 0.3424312103456921


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 23, loss: 0.341596723485876


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 24, loss: 0.34087039949717346


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 25, loss: 0.34009764126053565


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 26, loss: 0.3393282250342546


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 27, loss: 0.33883676025602555


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 28, loss: 0.33682286346400225


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 29, loss: 0.33498678871878873


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 30, loss: 0.3325402898479391


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 31, loss: 0.3015060856827983


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 32, loss: 0.29436240238172035


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 33, loss: 0.29318671411938135


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 34, loss: 0.2925268644094467


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 35, loss: 0.2909067207574844


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 36, loss: 0.2903902288940218


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 37, loss: 0.2903652536869049


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 38, loss: 0.28941279654149654


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 39, loss: 0.28859706428315907


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 40, loss: 0.2880369782668573


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 41, loss: 0.2874571966462665


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 42, loss: 0.2870155252792217


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 43, loss: 0.28989360586360646


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 44, loss: 0.2897112865801211


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 45, loss: 0.28943974864703637


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 46, loss: 0.28436113832173526


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 47, loss: 0.2844356631680771


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 48, loss: 0.2854254350176564


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 49, loss: 0.2848458303124816


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 50, loss: 0.27952338499051554


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 51, loss: 0.27712494607324956


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 52, loss: 0.28260943785861686


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 53, loss: 0.277525310990987


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 54, loss: 0.2768168479645694


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 55, loss: 0.2733815524533943


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 56, loss: 0.27412545331098415


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 57, loss: 0.2741082254604057


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 58, loss: 0.26970243037850766


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 59, loss: 0.2693844668511991


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 60, loss: 0.2662626920254142


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 61, loss: 0.2665757580818953


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 62, loss: 0.2674764945109685


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 63, loss: 0.2656246705629208


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 64, loss: 0.26442401826381684


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 65, loss: 0.26382493908758514


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 66, loss: 0.26357953622385305


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 67, loss: 0.26316734897869604


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 68, loss: 0.2622307198577457


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 69, loss: 0.2620652961620578


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 70, loss: 0.261873660562215


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 71, loss: 0.26155607786443497


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 72, loss: 0.2610508546895451


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 73, loss: 0.25973954444682157


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 74, loss: 0.25985474660440727


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 75, loss: 0.260892819000615


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 76, loss: 0.25833308969382884


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 77, loss: 0.25773693067056164


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 78, loss: 0.25698505290128565


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 79, loss: 0.25635717735246377


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 80, loss: 0.25593467819469945


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 81, loss: 0.25578004390001297


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 82, loss: 0.2555265118236895


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 83, loss: 0.25493624166206075


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 84, loss: 0.2548559594706253


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 85, loss: 0.25475116963739747


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 86, loss: 0.2541785341611615


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 87, loss: 0.25378239266298436


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 88, loss: 0.2533350311274882


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 89, loss: 0.2528824180474988


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 90, loss: 0.25282794574896494


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 91, loss: 0.25215497233249523


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 92, loss: 0.2519500740038024


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 93, loss: 0.25138941249361746


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 94, loss: 0.25124471462435194


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 95, loss: 0.2501590263733157


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 96, loss: 0.24998180334214812


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 97, loss: 0.2495745614502165


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 98, loss: 0.24891649877583538


  0%|          | 0/675 [00:00<?, ?it/s]

epoch 99, loss: 0.24864840070406596


In [92]:
with torch.no_grad():
    x = trained_model(x_train)
x

tensor([[-0.1137, -0.2733, -0.0503,  ..., -0.2566,  0.7133, -0.1103],
        [-0.0462,  0.1249, -0.0374,  ...,  0.0365,  1.0000, -0.1155],
        [-0.0892, -0.0844,  0.0295,  ..., -0.0151,  0.0464,  1.0000],
        ...,
        [ 0.0583, -0.2532, -0.2229,  ..., -0.4944,  1.0000, -0.6332],
        [ 0.0698, -0.1679, -0.1749,  ..., -0.3183,  1.0000, -0.4107],
        [ 0.0287,  0.6973,  0.0945,  ...,  0.3519,  0.8214,  0.5342]])

In [93]:
x[0]

tensor([-0.1137, -0.2733, -0.0503, -0.1976, -0.0373, -0.0213, -1.0000, -0.2566,
         0.7133, -0.1103])

In [49]:
y_train

tensor([[ 0.,  0.,  0.,  ...,  0.,  0.,  1.],
        [ 0.,  0.,  0.,  ...,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  ...,  0.,  0.,  1.],
        ...,
        [ 0.,  0.,  0.,  ..., -1.,  0.,  0.],
        [ 0., -1.,  0.,  ...,  0.,  1.,  0.],
        [ 0.,  1.,  0.,  ...,  0.,  0.,  1.]])

In [96]:
(x - y_train).square().mean()

tensor(0.2478)

In [99]:
decode_label(y_train[0])

[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [95]:
decode_label(x[0])

[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]

In [89]:
print(torch.__version__)

1.12.1+cu102
