In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = 'cuda'

class Head(nn.Module):
    def __init__(self,head_size):
        super().__init__()
        self.query = nn.Linear(n_embed,head_size,bias=False)
        self.key = nn.Linear(n_embed,head_size,bias=False)
        self.value = nn.Linear(n_embed,head_size,bias=False)
    def forward(self,x):
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        w = k @ q.transpose(-2,-1)
        w = F.softmax(w,dim=-1)
        out = w @ v
        return out

class MultiHead(nn.Module):
    def __init__(self,head_size,n_heads):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_heads)])
        self.proj = nn.Linear(n_embed,n_embed)
    def forward(self,x):
        out = torch.cat([head(x) for head in self.heads],-1)
        out = self.proj(out)
        return out

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed,n_embed),
            nn.ReLU(),
            nn.Linear(n_embed,n_embed),
        )
    def forward(self,x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.multihead = MultiHead(head_size,n_heads)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
    
    def forward(self,x):
        x = self.ln1(x)
        x = x + self.multihead(x)
        x = self.ln2(x)
        x = x + self.ffwd(x)
        return x


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,n_embed)
        self.positional_embedding = nn.Embedding(block_size,n_embed)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])
        self.ln = nn.LayerNorm(n_embed)
        self.cl_head = nn.Sequential(
            nn.Linear(n_embed,2),
        )
    def forward(self,x,targets=None):
        ini_emb = self.embedding(x)
        pos_emb = self.positional_embedding(torch.arange(block_size,device=device))
        x = ini_emb + pos_emb
        x = self.blocks(x)
        x = self.ln(x)
        x = self.cl_head(x)
        if targets != None:
            loss = F.cross_entropy(x,targets)
            return x,loss
        else:
            return x

In [61]:
import pandas as pd

In [62]:
messages = pd.read_csv('spam_ham_dataset.csv')

In [63]:
messages = messages[['label_num','text']]

In [64]:
messages.head()

Unnamed: 0,label_num,text
0,0,Subject: enron methanol ; meter # : 988291\r\n...
1,0,"Subject: hpl nom for january 9 , 2001\r\n( see..."
2,0,"Subject: neon retreat\r\nho ho ho , we ' re ar..."
3,1,"Subject: photoshop , windows , office . cheap ..."
4,0,Subject: re : indian springs\r\nthis deal is t...


In [65]:
import string
def clean_text(text):
    # Convert text to lowercase
    text = text.lower()
    
    # Remove punctuation using string library
    text = text.translate(str.maketrans('', '', string.punctuation))
    
    return text

# Apply the cleaning function to the 'text' column
messages['text'] = messages['text'].apply(clean_text)


In [66]:
messages

Unnamed: 0,label_num,text
0,0,subject enron methanol meter 988291\r\nthis...
1,0,subject hpl nom for january 9 2001\r\n see at...
2,0,subject neon retreat\r\nho ho ho we re aroun...
3,1,subject photoshop windows office cheap mai...
4,0,subject re indian springs\r\nthis deal is to ...
...,...,...
5166,0,subject put the 10 on the ft\r\nthe transport ...
5167,0,subject 3 4 2000 and following noms\r\nhpl c...
5168,0,subject calpine daily gas nomination\r\n\r\n\r...
5169,0,subject industrial worksheets for august 2000 ...


In [68]:
import nltk
from nltk.tokenize import word_tokenize

In [70]:
nltk.download('punkt')

[nltk_data] Error loading punkt: <urlopen error [WinError 10060] A
[nltk_data]     connection attempt failed because the connected party
[nltk_data]     did not properly respond after a period of time, or
[nltk_data]     established connection failed because connected host
[nltk_data]     has failed to respond>


False

In [71]:
messages['tokenized_text'] = messages['text'].apply(lambda x: word_tokenize(x))

LookupError: 
**********************************************************************
  Resource [93mpunkt[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt/english.pickle[0m

  Searched in:
    - 'C:\\Users\\omalv/nltk_data'
    - 'C:\\Users\\omalv\\.conda\\envs\\pytorch\\nltk_data'
    - 'C:\\Users\\omalv\\.conda\\envs\\pytorch\\share\\nltk_data'
    - 'C:\\Users\\omalv\\.conda\\envs\\pytorch\\lib\\nltk_data'
    - 'C:\\Users\\omalv\\AppData\\Roaming\\nltk_data'
    - 'C:\\nltk_data'
    - 'D:\\nltk_data'
    - 'E:\\nltk_data'
    - ''
**********************************************************************


In [None]:
c