In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.vocab import GloVe

class EmailClassifier(nn.Module):
    """Two-headed email classifier over frozen GloVe embeddings.

    Subject and body index sequences share one frozen GloVe embedding table.
    Each sequence is encoded by its own bidirectional LSTM and pooled with an
    ``Attention`` layer. The two pooled vectors are concatenated and fed
    through a shared dense layer into two heads:

    * ``category_classifier`` -- 3 logits per example.
    * ``sensitivity_classifier`` -- 1 logit per example.

    Args:
        hidden_dim: Hidden size of each LSTM direction; also the width of the
            shared dense layer.
        emb_dim: Embedding dimension requested from ``GloVe(dim=...)``; must be
            a size offered by the chosen GloVe release.
        dropout: Dropout probability applied after the shared dense layer.
        glove_version: GloVe release name passed to ``GloVe(name=...)``.
    """

    def __init__(self, hidden_dim=128, emb_dim=300, dropout=0.2, glove_version='6B'):
        super().__init__()
        # Use the requested GloVe release rather than a hard-coded one.
        self.glove = GloVe(name=glove_version, dim=emb_dim)
        # Frozen pretrained table; preprocess_text emits glove.stoi indices
        # that are looked up in it.
        self.embedding = nn.Embedding.from_pretrained(self.glove.vectors, freeze=True)

        # Bidirectional encoders: each time step yields 2 * hidden_dim features.
        self.subject_encoder = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.subject_attention = Attention(hidden_dim * 2)
        self.body_encoder = nn.LSTM(emb_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.body_attention = Attention(hidden_dim * 2)

        # Subject + body summaries (2 * hidden_dim each) -> 4 * hidden_dim input.
        self.shared_dense = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.category_classifier = nn.Linear(hidden_dim, 3)
        self.sensitivity_classifier = nn.Linear(hidden_dim, 1)

    def forward(self, subject, body):
        """Compute category and sensitivity logits.

        Args:
            subject: Token indices for the subject; assumed shape
                (batch, subject_len) since the LSTMs are batch_first -- confirm
                against the data loader.
            body: Token indices for the body; assumed (batch, body_len).

        Returns:
            ``(category_logits, sensitivity_logits)`` with trailing dimensions
            3 and 1 respectively.
        """
        subject_emb = self.embedding(subject)
        body_emb = self.embedding(body)

        subject_output, _ = self.subject_encoder(subject_emb)
        subject_encoded = self.subject_attention(subject_output)

        body_output, _ = self.body_encoder(body_emb)
        body_encoded = self.body_attention(body_output)

        combined = torch.cat([subject_encoded, body_encoded], dim=1)
        shared_output = self.shared_dense(combined)

        category_logits = self.category_classifier(shared_output)
        sensitivity_logits = self.sensitivity_classifier(shared_output)

        return category_logits, sensitivity_logits

class Attention(nn.Module):
    """Attention pooling over the sequence axis of encoder outputs.

    A single linear layer scores every time step. The scores are softmaxed
    over dim 1 and used to take a weighted sum of the encoder outputs, which
    collapses the sequence dimension.

    Args:
        hidden_dim: Feature size of each encoder output step.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, encoder_outputs, mask=None):
        """Pool ``encoder_outputs`` into one vector per batch element.

        Args:
            encoder_outputs: Tensor of shape (batch, seq_len, hidden_dim), as
                produced by a batch_first LSTM.
            mask: Optional (batch, seq_len) tensor, truthy for real tokens and
                falsy for padding. Masked steps receive zero weight. ``None``
                (the default) attends to every step, matching the original
                behaviour.

        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        scores = self.attention(encoder_outputs)  # (batch, seq_len, 1)
        if mask is not None:
            # Exclude padding from the softmax by giving it -inf score.
            scores = scores.masked_fill(~mask.bool().unsqueeze(-1), float('-inf'))
        attention_weights = F.softmax(scores, dim=1)
        if mask is not None:
            # A row that is entirely masked yields NaN from softmax; pool it to zeros.
            attention_weights = torch.nan_to_num(attention_weights, nan=0.0)
        return torch.sum(attention_weights * encoder_outputs, dim=1)

class Sensitivity(nn.Module):
    """Combine category probabilities with an academic sensitivity score.

    Computes ``P(corp) + P(acad) * acad_sensitivity``. Category column order
    is 0 = student, 1 = corporate, 2 = academic; students contribute nothing.
    Presumably this is an overall "email is sensitive" probability in which
    corporate mail always counts as sensitive -- confirm with the model owner.
    """

    def __init__(self):
        super().__init__()

    def forward(self, category_probs, acad_sensitivity):
        """Return the combined sensitivity score.

        Args:
            category_probs: (batch, 3) category probabilities.
            acad_sensitivity: Academic sensitivity score; assumed (batch, 1)
                (e.g. a sigmoid of the sensitivity head) -- TODO confirm.

        Returns:
            Tensor broadcast from (batch, 1) and ``acad_sensitivity``.
        """
        corp_prob = category_probs[:, 1].unsqueeze(1)
        acad_prob = category_probs[:, 2].unsqueeze(1)
        return corp_prob + acad_prob * acad_sensitivity

def loss(category_logits, sensitivity_logits, category_labels, sensitivity_labels, alpha):
    """Weighted sum of category cross-entropy and academic-only sensitivity BCE.

    The sensitivity term is computed only for rows whose category label is 2
    (academic) and is averaged over those rows. When a batch has no academic
    rows, the term is 0.

    Args:
        category_logits: (batch, 3) category logits.
        sensitivity_logits: (batch, 1) or (batch,) sensitivity logits.
        category_labels: (batch,) integer class labels.
        sensitivity_labels: (batch,) binary targets (int or float). Values on
            non-academic rows are ignored.
        alpha: Weight of the category loss; ``1 - alpha`` weights the
            sensitivity loss.

    Returns:
        Scalar loss tensor.
    """
    category_loss = F.cross_entropy(category_logits, category_labels)

    academic = category_labels == 2
    # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into a
    # 0-d tensor that no longer matches the label shape.
    logits = sensitivity_logits.reshape(-1)
    targets = sensitivity_labels.reshape(-1).to(logits.dtype)
    per_sample = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    # Select instead of multiplying by the mask, so a NaN on an ignored row cannot leak in.
    per_sample = torch.where(academic, per_sample, torch.zeros_like(per_sample))
    # Average over academic rows only, so the term's scale does not depend on
    # the fraction of academic emails in the batch.
    sensitivity_loss = per_sample.sum() / academic.sum().clamp(min=1)

    return alpha * category_loss + (1 - alpha) * sensitivity_loss

def preprocess_text(text, max_length, glove, unk_index=None, pad_index=None):
    """Convert text into a fixed-length LongTensor of vocabulary indices.

    The text is lower-cased and whitespace-split, then truncated to
    ``max_length`` tokens. Each token is mapped through ``glove.stoi``, with
    out-of-vocabulary tokens mapped to the unknown index. The result is
    right-padded with the padding index up to ``max_length``.

    Args:
        text: Raw input string.
        max_length: Exact length of the returned tensor; must be >= 0.
        glove: Object with a ``stoi`` token -> index mapping (e.g. torchtext
            ``GloVe``).
        unk_index: Index for unknown tokens. Defaults to ``glove.stoi['<unk>']``,
            which is looked up only if an unknown token actually occurs.
        pad_index: Index for padding. Defaults to ``glove.stoi['<pad>']``,
            which is looked up only if padding is needed.
            NOTE(review): stock GloVe vocabularies may not contain '<unk>' or
            '<pad>'; pass explicit indices if so -- confirm against the vectors.

    Returns:
        1-D ``torch.long`` tensor of length ``max_length``.

    Raises:
        ValueError: If ``max_length`` is negative.
        KeyError: If a default special token is needed but missing from ``stoi``.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    stoi = glove.stoi
    tokens = text.lower().split()[:max_length]  # simple whitespace tokenization
    if unk_index is None and any(token not in stoi for token in tokens):
        unk_index = stoi['<unk>']
    indices = [stoi.get(token, unk_index) for token in tokens]
    if len(indices) < max_length:
        if pad_index is None:
            pad_index = stoi['<pad>']
        indices.extend([pad_index] * (max_length - len(indices)))
    # Explicit dtype: torch.tensor([]) would otherwise be float32.
    return torch.tensor(indices, dtype=torch.long)
        

ModuleNotFoundError: No module named 'torchtext'