In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import string
import torch
import torch.nn as nn

In [2]:
splits = {'train': 'plain_text/train-00000-of-00001.parquet', 'test': 'plain_text/test-00000-of-00001.parquet', 'unsupervised': 'plain_text/unsupervised-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/stanfordnlp/imdb/" + splits["train"])

In [3]:
df.head(10)

Unnamed: 0,text,label
0,I rented I AM CURIOUS-YELLOW from my video sto...,0
1,"""I Am Curious: Yellow"" is a risible and preten...",0
2,If only to avoid making this type of film in t...,0
3,This film was probably inspired by Godard's Ma...,0
4,"Oh, brother...after hearing about this ridicul...",0
5,I would put this at the top of my list of film...,0
6,Whoever wrote the screenplay for this movie ob...,0
7,"When I first saw a glimpse of this movie, I qu...",0
8,"Who are these ""They""- the actors? the filmmake...",0
9,This is said to be a personal film for Peter B...,0


In [4]:
df1 = df[df['label'] == 0]    
df2 = df[df['label'] == 1] 

In [5]:
df1 = df1[:2000]
df2 = df2[:2000]

In [6]:
df = pd.concat([df1, df2], ignore_index=True)
df.head()

Unnamed: 0,text,label
0,I rented I AM CURIOUS-YELLOW from my video sto...,0
1,"""I Am Curious: Yellow"" is a risible and preten...",0
2,If only to avoid making this type of film in t...,0
3,This film was probably inspired by Godard's Ma...,0
4,"Oh, brother...after hearing about this ridicul...",0


In [7]:
data = df['text'].tolist()
data[:3]

['I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, e

In [8]:
s = pd.Series(data)
s[:3]

0    I rented I AM CURIOUS-YELLOW from my video sto...
1    "I Am Curious: Yellow" is a risible and preten...
2    If only to avoid making this type of film in t...
dtype: object

In [9]:
stop_words = set(stopwords.words('english'))
def preprocess(text):
    text = text.lower()
    text = ''.join([word for word in text if word not in string.punctuation])
    tokens = word_tokenize(text)
    tokens = [word for word in tokens if word not in stop_words]
    return ' '.join(tokens)

In [10]:
s_upd = s.apply(preprocess)
s_upd[:3]

0    rented curiousyellow video store controversy s...
1    curious yellow risible pretentious steaming pi...
2    avoid making type film future film interesting...
dtype: object

In [11]:
clean_df = pd.DataFrame(s_upd, columns=['text'])
clean_df['label'] = df['label']
clean_df.head()

Unnamed: 0,text,label
0,rented curiousyellow video store controversy s...,0
1,curious yellow risible pretentious steaming pi...,0
2,avoid making type film future film interesting...,0
3,film probably inspired godards masculin fémini...,0
4,oh brotherafter hearing ridiculous film umptee...,0


In [12]:
data_lst = clean_df['text'].tolist()
data_lst[:3]

['rented curiousyellow video store controversy surrounded first released 1967 also heard first seized us customs ever tried enter country therefore fan films considered controversial really see myselfbr br plot centered around young swedish drama student named lena wants learn everything life particular wants focus attentions making sort documentary average swede thought certain political issues vietnam war race issues united states asking politicians ordinary denizens stockholm opinions politics sex drama teacher classmates married menbr br kills curiousyellow 40 years ago considered pornographic really sex nudity scenes far even shot like cheaply made porno countrymen mind find shocking reality sex nudity major staple swedish cinema even ingmar bergman arguably answer good old boy john ford sex scenes filmsbr br commend filmmakers fact sex shown film shown artistic purposes rather shock people make money shown pornographic theaters america curiousyellow good film anyone wanting study

In [13]:
vocab = list(set(word for sentence in data_lst for word in sentence.split()))
vocab[:3]

['finally', 'beals', 'hautefeuille']

In [14]:
word_to_index = {word: i for i, word in enumerate(vocab)}
items = list(word_to_index.items())
items[:10]

[('finally', 0),
 ('beals', 1),
 ('hautefeuille', 2),
 ('disgruntled', 3),
 ('filmsas', 4),
 ('lesser', 5),
 ('viceterminatrix', 6),
 ('impeding', 7),
 ('nerd', 8),
 ('istanbul', 9)]

In [15]:
sentences_indices = [[word_to_index[word] for word in sentence.split()] for sentence in data_lst]
sentences_indices[:3]

[[6815,
  38277,
  31399,
  16713,
  30257,
  1045,
  24471,
  14145,
  27917,
  22722,
  1976,
  24471,
  20039,
  26908,
  13566,
  37078,
  1118,
  28986,
  1923,
  21483,
  4065,
  25452,
  41063,
  28621,
  2577,
  1706,
  17836,
  27970,
  26538,
  32363,
  18012,
  6860,
  40361,
  551,
  2709,
  41664,
  31148,
  35220,
  18673,
  25394,
  27092,
  7683,
  35220,
  31502,
  19236,
  13201,
  25254,
  13026,
  23890,
  15284,
  29755,
  5947,
  29652,
  35098,
  30829,
  33308,
  34783,
  35098,
  11858,
  15852,
  2653,
  40582,
  8933,
  11565,
  2284,
  6509,
  7797,
  9286,
  551,
  25529,
  42038,
  37841,
  16775,
  27970,
  20860,
  38277,
  24313,
  2103,
  21307,
  41063,
  9620,
  2577,
  9286,
  1833,
  34116,
  805,
  22371,
  42338,
  37199,
  2735,
  21131,
  20761,
  8120,
  1315,
  11266,
  15142,
  42035,
  9286,
  1833,
  32015,
  212,
  40361,
  35737,
  22371,
  10338,
  71,
  39647,
  30318,
  17503,
  40964,
  25920,
  7663,
  2052,
  9286,
  34116,
  30156

In [16]:
vocab_size = len(vocab)
embedding_dim = 7

In [17]:
class Network(nn.Module):
    def __init__(self, voc_size, embed_size):
        super(Network, self).__init__()
        self.embedding = nn.Embedding(voc_size, embed_size) 
        
    def forward(self, inputs):                   
        embeds = self.embedding(inputs)           
        return embeds

In [18]:
model = Network(vocab_size, embedding_dim)

word_emb = model.embedding(torch.tensor(word_to_index["balding"], dtype=torch.long))
print(word_emb.shape)
print(word_emb)

emb_array = word_emb.detach().numpy()
emb_array

torch.Size([7])
tensor([ 1.1985,  1.2901, -0.2415,  0.6737, -0.3497, -0.2661, -0.8989],
       grad_fn=<EmbeddingBackward0>)


array([ 1.1984642 ,  1.290061  , -0.24146774,  0.6737092 , -0.34974918,
       -0.26610819, -0.8989188 ], dtype=float32)

In [19]:
X_train, X_test, y_train, y_test = train_test_split(clean_df['text'], clean_df['label'], test_size=0.2, random_state=42)

In [20]:
sentences = [sentence.split() for sentence in X_train]

In [21]:
def vectorize(sentence):
    words = sentence.split()
    words_vecs = [model.embedding(torch.tensor(word_to_index[word], dtype=torch.long)) for word in words if word in vocab]

    emb_arrs = [el.detach().numpy() for el in words_vecs]
    
    if len(emb_arrs) == 0:
        return np.zeros(100)
    emb_arrs = np.array(emb_arrs)
    return emb_arrs.mean(axis=0)

In [22]:
X_train = np.array([vectorize(sentence) for sentence in X_train])
X_test = np.array([vectorize(sentence) for sentence in X_test])

In [23]:
clf = LogisticRegression()
clf.fit(X_train, y_train)

In [24]:
y_pred = clf.predict(X_test)
print('Accuracy:', accuracy_score(y_test, y_pred))
print('Precision:', precision_score(y_test, y_pred, pos_label=1))
print('Recall:', recall_score(y_test, y_pred, pos_label=1))
print('F1 score:', f1_score(y_test, y_pred, pos_label=1))

Accuracy: 0.49875
Precision: 0.47784200385356457
Recall: 0.656084656084656
F1 score: 0.5529542920847269
