<a href="https://colab.research.google.com/gist/22961-Deep-learning/27700ad4f979c1760672619a577cc209/22961_6_2_embedding_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torch import nn
import datasets as ds
from pprint import pprint
from tqdm import tqdm

# Seed torch's RNG so embedding initialisation and the random-label check
# further down are reproducible under Restart & Run All.
torch.manual_seed(0)

dataset = ds.load_dataset("glue", "sst2")

sentence_list = dataset["train"]["sentence"]
labels_list = dataset["train"]["label"]


def tokenize(x):
    """Whitespace tokenizer shared by training and inference preprocessing."""
    return x.split()


tokenized = list(map(tokenize, sentence_list))

from torchtext.vocab import build_vocab_from_iterator
vocab = build_vocab_from_iterator(tokenized, specials=["<UNK>"], min_freq=5)
# Out-of-vocabulary tokens must map to the "<UNK>" special, looked up by
# name. (torchtext places specials first by default, so a hard-coded index 1
# would be an ordinary vocabulary word, not "<UNK>".)
vocab.set_default_index(vocab["<UNK>"])


def func(x):
    """Map a list of tokens to a tensor of vocabulary ids."""
    # Explicit int64 so an empty token list still yields valid Embedding input
    # (torch.tensor([]) alone would be float32).
    return torch.tensor(vocab(x), dtype=torch.long)


integer_tokens = list(map(func, tokenized))
label_tensors = list(map(torch.tensor, labels_list))
print(*sentence_list[1:3], sep="\n")
print(*integer_tokens[1:3], sep="\n")
print(*label_tensors[1:3], sep="\n")
print(*label_tensors[1:3],sep="\n")

Reusing dataset glue (/Users/shlomidomnenko/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████| 3/3 [00:00<00:00, 365.13it/s]


contains no wit , only labored gags 
that loves its characters and communicates something rather beautiful about human nature 
tensor([2924,   61,  330,    2,   89, 1993,  549])t
tensor([  10, 1792,   17,   54,    4, 6088,   96,  186,  265,   34,  178,  627])
tensor(0)
tensor(1)


In [3]:
# 80/20 split of the labelled training data, taken in dataset order (no shuffling).
test_split = len(integer_tokens) * 8 // 10
train_tokens, test_tokens = integer_tokens[:test_split], integer_tokens[test_split:]
train_labels, test_labels = label_tensors[:test_split], label_tensors[test_split:]

In [4]:
class ClassificationHead(nn.Module):
    """Linear two-class classifier that returns log-probabilities.

    Args:
        in_features: size of the last dimension of the input features.

    forward(feature_extractor_output) maps a tensor of shape
    (..., in_features) to log-probabilities of shape (..., 2).
    """

    def __init__(self, in_features):
        super().__init__()
        self.linear = nn.Linear(in_features, 2)
        # Normalise over the class axis (last dim). For an unbatched
        # (in_features,) input this equals dim=0; for a (batch, in_features)
        # input dim=0 would wrongly normalise across samples.
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, feature_extractor_output):
        class_scores = self.linear(feature_extractor_output)
        logprobs = self.logsoftmax(class_scores)
        return logprobs

In [5]:
class FeatureExtractor_1(nn.Module):
    """Token-embedding lookup with no pooling (first, illustrative version).

    Args:
        embed_dim: size of each embedding vector.
        vocab_size: number of rows in the embedding table. Defaults to
            ``len(vocab)`` (the global vocabulary) when omitted.

    forward(sentence_tokens) returns one embedding per token id, i.e. a
    tensor of shape ``sentence_tokens.shape + (embed_dim,)``.
    """

    def __init__(self, embed_dim, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            # Fall back to the global `vocab` so existing calls keep working.
            vocab_size = len(vocab)
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, sentence_tokens):
        embedded = self.embedding(sentence_tokens)
        return embedded

In [6]:
# Running example for the next few cells: the second training sentence.
example_sentence = sentence_list[1]

In [7]:
print(example_sentence)


def preprocess(sentence):
    """Tokenize a raw sentence and map it to a tensor of vocabulary ids.

    Uses the same `tokenize` as the training data. The explicit int64 dtype
    keeps an empty sentence valid input for nn.Embedding (torch.tensor([])
    alone would be float32).
    """
    return torch.tensor(vocab(tokenize(sentence)), dtype=torch.long)


tokens = preprocess(example_sentence)
print(tokens)

extractor = FeatureExtractor_1(2)
features = extractor(tokens)
print(features, features.size(), sep="\n")


contains no wit , only labored gags 
tensor([2924,   61,  330,    2,   89, 1993,  549])
tensor([[-0.4834, -1.3486],
        [ 1.2447, -0.5322],
        [ 0.7654, -1.4073],
        [ 0.4062,  0.4013],
        [-1.0043, -0.4094],
        [-0.4944, -0.4208],
        [-0.9554, -0.2330]], grad_fn=<EmbeddingBackward0>)
torch.Size([7, 2])


In [8]:
class FeatureExtractor(nn.Module):
    """Bag-of-embeddings feature extractor: embed each token, then sum.

    Args:
        embed_dim: size of each embedding vector (and of the output).
        vocab_size: number of rows in the embedding table. Defaults to
            ``len(vocab)`` (the global vocabulary) when omitted.

    forward(sentence_tokens) maps token ids of shape (num_tokens,) to one
    (embed_dim,) vector, or a (batch, num_tokens) batch to (batch, embed_dim).
    Token order only affects floating-point rounding of the sum.
    """

    def __init__(self, embed_dim, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            # Fall back to the global `vocab` so existing calls keep working.
            vocab_size = len(vocab)
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, sentence_tokens):
        embedded = self.embedding(sentence_tokens)
        # Sum over the token axis (second-to-last). For a single sentence
        # this equals dim=0; unlike dim=0 it does not sum across a batch.
        feature_extractor_output = embedded.sum(dim=-2)
        return feature_extractor_output

In [9]:
# Summing over tokens collapses the (num_tokens, 2) embeddings into one vector.
extractor = FeatureExtractor(2)
features = extractor(tokens)
print(features, features.shape, sep="\n")

tensor([ 0.5171, -2.3180], grad_fn=<SumBackward1>)
torch.Size([2])


In [10]:
class EmbedSumClassify(nn.Module):
    """Bag-of-embeddings sentiment classifier.

    Chains a ``FeatureExtractor`` (embedding lookup summed over tokens) into
    a ``ClassificationHead`` (linear layer followed by log-softmax).

    Args:
        embed_dim: embedding size, which is also the head's input size.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # These attribute names determine the state_dict keys; keep them stable.
        self.extractor = FeatureExtractor(embed_dim)
        self.classifier = ClassificationHead(embed_dim)

    def forward(self, sentence_tokens):
        """Return class log-probabilities for ``sentence_tokens``."""
        return self.classifier(self.extractor(sentence_tokens))

In [11]:
# Untrained model with 2-d embeddings; prints log-probabilities for the 2 classes.
model = EmbedSumClassify(2)
print(model(tokens))

tensor([-2.4344, -0.0917], grad_fn=<LogSoftmaxBackward0>)


In [12]:
def iterate_one_sentence(tokens,label,train_flag):
  """Run the global `model` on one sentence and report whether it was right.

  Args:
    tokens: token-id tensor for a single sentence (the argmax over dim 0
      below assumes the model output is unbatched).
    label: gold class; indexes the log-probabilities when training and is
      compared with the prediction.
    train_flag: if True, take one step of the global `optimizer` on
      -log p(label); otherwise only run inference.

  Returns:
    0-d bool tensor: whether argmax of the model output equals `label`.
    When training, this is the prediction made before the update.
  """
  if train_flag:
    model.train()
    optimizer.zero_grad()
    y_model = model(tokens)
    # The model returns log-probabilities, so this is the per-example
    # negative log-likelihood (cross-entropy) loss.
    loss = -y_model[label]
    loss.backward()
    optimizer.step()
  else:
    # Inference only: build no autograd graph, and put the model back into
    # whatever mode it was in before this call.
    was_training = model.training
    model.eval()
    with torch.no_grad():
      y_model = model(tokens)
    model.train(was_training)
  with torch.no_grad():
    predicted_labels = y_model.argmax(dim=0)
    success = (predicted_labels == label)
  return success

def train_one_epoch():
  """Make one pass over the training split, updating `model` per sentence.

  Returns:
    1-element float tensor: fraction of training sentences classified
    correctly, each judged before the optimizer step on it (a running
    training accuracy, not an end-of-epoch evaluation). Also printed.
  """
  n_correct = torch.tensor([0.])
  examples = zip(train_tokens, train_labels)
  for sentence_tokens, gold in tqdm(examples, total=len(train_tokens)):
    n_correct += iterate_one_sentence(sentence_tokens, gold, train_flag=True)
  acc = n_correct / len(train_tokens)
  print("\n", acc)
  return acc

def test_model():
  test_correct_predictions=torch.tensor([0.])
  for tokens,label in tqdm(zip(test_tokens,test_labels),total=len(test_tokens)):
    test_correct_predictions += iterate_one_sentence(tokens,label,train_flag=False)
  test_acc=test_correct_predictions/len(test_tokens)
  return test_acc

In [13]:
# Fresh model with 5-d embeddings: one epoch of Adam, then held-out evaluation.
model = EmbedSumClassify(5)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.05)

acc = train_one_epoch()
test_acc = test_model()

100%|██████████| 53879/53879 [01:14<00:00, 719.45it/s]



 tensor([0.7647])


100%|██████████| 13470/13470 [00:02<00:00, 5223.63it/s]


In [14]:
#check on random labels
test_correct_predictions=torch.tensor([0.])
random_labels=torch.rand(len(test_tokens))<0.5
for tokens,label in tqdm(zip(test_tokens,random_labels),total=len(test_tokens)):
  test_correct_predictions += iterate_one_sentence(tokens,label,train_flag=False)
rand_acc=test_correct_predictions/len(test_tokens)

100%|██████████| 13470/13470 [00:02<00:00, 5221.64it/s]


In [15]:
# Running train accuracy, held-out accuracy, and the random-label baseline.
print("\n".join(str(metric) for metric in (acc, test_acc, rand_acc)))

tensor([0.7647])
tensor([0.8366])
tensor([0.4970])


In [16]:
# The features are a sum of token embeddings, so word order is ignored:
# these two sentences contain the same tokens and therefore get the same
# class probabilities (up to floating-point rounding).
# `preprocess` is the helper defined in the earlier example cell.
example_sentences = ["very good , not bad",
                     "very bad , not good"]
with torch.no_grad():
  for sent in example_sentences:
    sent_tokens = preprocess(sent)
    print(sent_tokens)
    print(torch.exp(model(sent_tokens)))


tensor([77, 46,  2, 33, 74])
tensor([9.9916e-01, 8.3970e-04])
tensor([77, 74,  2, 33, 46])
tensor([9.9916e-01, 8.3970e-04])
