<a href="https://colab.research.google.com/github/Skander28/Models/blob/main/LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [41]:
import torch

if torch.cuda.is_available():
  device = torch.device("cuda")
  print('There are %d GPU(s) available.' % torch.cuda.device_count())
  print('We will use the GPU:', torch.cuda.get_device_name(0))
  !nvidia-smi
else:
  print('No GPU available, using the CPU instead.')
  device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Tesla T4
Fri Apr  7 07:47:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P0    28W /  70W |   8477MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------

In [42]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [43]:
!ls /content/drive/MyDrive/

 Classroom	    dialect_dataset.csv   preprocessed_tweets.csv
'Colab Notebooks'   messages.csv	  pre_tweets.csv


In [44]:
import pandas as pd
filtered_df = pd.read_csv('/content/drive/MyDrive/pre_tweets.csv')

In [45]:
filtered_df['dialect'].value_counts()

1    36499
3    36499
2    36499
0    36499
Name: dialect, dtype: int64

In [46]:
filtered_df = filtered_df.dropna()

In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from keras.preprocessing.text import Tokenizer
from keras.utils import pad_sequences


In [48]:
# Load data
features = filtered_df.tweets.values
labels = pd.get_dummies(filtered_df['dialect']).values

In [49]:
vocab_size = 20000
max_length= 200
tokenizer = Tokenizer(num_words=vocab_size,filters='!"#$%&()*+,-./:;<=>?@[\]^_`{|}~', lower=False)
tokenizer.fit_on_texts(features)
X = tokenizer.texts_to_sequences(features)


In [50]:
import nltk
from collections import Counter
nltk.download('stopwords')
from nltk.corpus import stopwords

# use the stopwords
stop_words = set(stopwords.words('arabic'))

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [51]:
all_words = []
for tokens in X:
    all_words.extend(tokens)

stop_words = set(stopwords.words('arabic'))
all_words = [word for word in all_words if word not in stop_words]

word_counts = Counter(all_words)
most_common_words = [word for word, count in word_counts.most_common(500)]

def remove_common_words(tokens):
  new_tokens = [token for token in tokens if token not in most_common_words]
  return new_tokens

X = remove_common_words(X)


X = pad_sequences(X, maxlen=max_length)


In [52]:
# Split data into training, validation, and test sets
X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.1, random_state=42, shuffle=True)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42, shuffle=True)


In [53]:
# Convert data to PyTorch tensors
X_train, y_res_ = torch.tensor(X_train), torch.tensor(y_train)
X_val, y_val = torch.tensor(X_val), torch.tensor(y_val)
X_test, y_test = torch.tensor(X_test), torch.tensor(y_test)

In [54]:
class DialectDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [55]:
# Create dataloaders for training, validation, and test sets
train_dataset = DialectDataset(X_train, y_train)
val_dataset = DialectDataset(X_val, y_val)
test_dataset = DialectDataset(X_test, y_test)

In [56]:
batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)


# **LSTM Model**

In [57]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np

output_dim = 100
batch_size = 64


class MyModel(nn.Module):
    def __init__(self, vocab_size, output_dim, max_length):
        super(MyModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, output_dim)
        self.dropout = nn.Dropout(p=0.5)
        self.lstm = nn.LSTM(output_dim, 50, dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(50, 25, dropout=0.3, batch_first=True)
        self.linear = nn.Linear(25, 4)
        softmax = nn.Softmax(dim=1)
        self.fc = nn.Sequential(self.linear, softmax)
    
    def forward(self, x):
        embedded = self.embedding(x)
        dropped = self.dropout(embedded)
        lstm_out, _ = self.lstm(dropped)
        lstm2_out, _ = self.lstm2(lstm_out)
        out = self.fc(lstm2_out[:, -1, :])
        return out

In [58]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MyModel(vocab_size, output_dim, max_length).to(device)
optimizer = optim.Adam(model.parameters(),lr=0.01)
criterion = nn.CrossEntropyLoss()
counter_nb = 0
for epoch in tqdm(range(100)):
    model.train()
    counter_nb = counter_nb + 1 
    print(counter_nb)
    running_loss = 0.0
    with torch.cuda.device(0):
      for batch in train_dataloader:
        inputs, labels = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        labels = labels.float()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / (len(train_dataloader))

    best_val_loss = np.inf
    patience = 5
    counter = 0
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            labels = labels.float()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    val_loss = val_loss / (len(val_dataloader))
    # Check if the validation loss has improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
    else:
        counter += 1

    # Stop the training process if the validation loss hasn't improved for `patience` epochs
    if counter >= patience:
        break
    print(" epoch loss :", epoch_loss , "| val loss :", val_loss)

print("Training stopped after epoch", epoch)
        

  0%|          | 0/100 [00:00<?, ?it/s]

1


  1%|          | 1/100 [00:10<17:21, 10.52s/it]

 epoch loss : 1.1745473210300719 | val loss : 1.0499655364207852
2


  2%|▏         | 2/100 [00:20<17:03, 10.44s/it]

 epoch loss : 1.0593587952581318 | val loss : 1.0167955252730732
3


  3%|▎         | 3/100 [00:31<16:47, 10.39s/it]

 epoch loss : 1.0255870142733896 | val loss : 0.995151860621369
4


  4%|▍         | 4/100 [00:40<16:06, 10.07s/it]

 epoch loss : 1.0069642446580387 | val loss : 0.9860842696671347
5


  5%|▌         | 5/100 [00:51<16:04, 10.15s/it]

 epoch loss : 0.9943246883682875 | val loss : 0.9812363814381719
6


  6%|▌         | 6/100 [01:01<15:58, 10.20s/it]

 epoch loss : 0.9853657896255518 | val loss : 0.9765449658180904
7


  7%|▋         | 7/100 [01:11<15:51, 10.24s/it]

 epoch loss : 0.9794956399893864 | val loss : 0.971653005451832
8


  8%|▊         | 8/100 [01:21<15:22, 10.03s/it]

 epoch loss : 0.9726959767860252 | val loss : 0.9682382296011286
9


  9%|▉         | 9/100 [01:31<15:20, 10.12s/it]

 epoch loss : 0.9680801205885359 | val loss : 0.9626888302923406
10


 10%|█         | 10/100 [01:42<15:20, 10.23s/it]

 epoch loss : 0.9645802862780951 | val loss : 0.9642749611035134
11


 11%|█         | 11/100 [01:52<15:14, 10.27s/it]

 epoch loss : 0.9612923417062986 | val loss : 0.9640112539518226
12


 12%|█▏        | 12/100 [02:02<14:51, 10.13s/it]

 epoch loss : 0.9586001526548232 | val loss : 0.9610658814606158
13


 13%|█▎        | 13/100 [02:12<14:39, 10.11s/it]

 epoch loss : 0.9581010385186641 | val loss : 0.95818331432574
14


 14%|█▍        | 14/100 [02:22<14:35, 10.18s/it]

 epoch loss : 0.9580661141317645 | val loss : 0.9621564705395004
15


 15%|█▌        | 15/100 [02:33<14:29, 10.23s/it]

 epoch loss : 0.9556322418056525 | val loss : 0.9593725641375607
16


 16%|█▌        | 16/100 [02:43<14:16, 10.19s/it]

 epoch loss : 0.9542631957502592 | val loss : 0.9578044698076341
17


 17%|█▋        | 17/100 [02:53<14:01, 10.13s/it]

 epoch loss : 0.9530955835596308 | val loss : 0.9554968855334717
18


 18%|█▊        | 18/100 [03:03<13:56, 10.21s/it]

 epoch loss : 0.9502769733468691 | val loss : 0.9540717665431568
19


 19%|█▉        | 19/100 [03:13<13:50, 10.25s/it]

 epoch loss : 0.9510467935885701 | val loss : 0.9550107161975602
20


 20%|██        | 20/100 [03:24<13:39, 10.24s/it]

 epoch loss : 0.9503671103419163 | val loss : 0.9541240872688663
21


 21%|██        | 21/100 [03:33<13:16, 10.08s/it]

 epoch loss : 0.948116287847102 | val loss : 0.9596839416952967
22


 22%|██▏       | 22/100 [03:44<13:14, 10.19s/it]

 epoch loss : 0.946734817255111 | val loss : 0.9527286635440531
23


 23%|██▎       | 23/100 [03:54<13:09, 10.25s/it]

 epoch loss : 0.9447302335139477 | val loss : 0.9503198416487685
24


 24%|██▍       | 24/100 [04:05<13:05, 10.33s/it]

 epoch loss : 0.9447770158062766 | val loss : 0.9523685701842447
25


 25%|██▌       | 25/100 [04:14<12:38, 10.11s/it]

 epoch loss : 0.943775971891818 | val loss : 0.9546899905482542
26


 26%|██▌       | 26/100 [04:25<12:35, 10.21s/it]

 epoch loss : 0.9444117753149627 | val loss : 0.9478288258163674
27


 27%|██▋       | 27/100 [04:35<12:28, 10.25s/it]

 epoch loss : 0.9440742691422438 | val loss : 0.951655685612299
28


 28%|██▊       | 28/100 [04:46<12:23, 10.32s/it]

 epoch loss : 0.9425203703763165 | val loss : 0.9496960955337413
29


 29%|██▉       | 29/100 [04:55<12:00, 10.15s/it]

 epoch loss : 0.9426831696585659 | val loss : 0.9513518133209747
30


 30%|███       | 30/100 [05:05<11:50, 10.15s/it]

 epoch loss : 0.9428148488977771 | val loss : 0.9492936724597968
31


 31%|███       | 31/100 [05:16<11:44, 10.21s/it]

 epoch loss : 0.9449285512298217 | val loss : 0.9496983625356433
32


 32%|███▏      | 32/100 [05:26<11:37, 10.26s/it]

 epoch loss : 0.944290433095131 | val loss : 0.949182652732701
33


 33%|███▎      | 33/100 [05:36<11:24, 10.21s/it]

 epoch loss : 0.9427621634878638 | val loss : 0.9480146404030254
34


 34%|███▍      | 34/100 [05:46<11:11, 10.17s/it]

 epoch loss : 0.9423725783695907 | val loss : 0.9507885254123836
35


 35%|███▌      | 35/100 [05:57<11:04, 10.23s/it]

 epoch loss : 0.9414602112653968 | val loss : 0.9508021537540028
36


 36%|███▌      | 36/100 [06:07<10:56, 10.26s/it]

 epoch loss : 0.9430154926784626 | val loss : 0.9516151093163536
37


 37%|███▋      | 37/100 [06:17<10:45, 10.25s/it]

 epoch loss : 0.942702701016938 | val loss : 0.9498400265730701
38


 38%|███▊      | 38/100 [06:27<10:25, 10.09s/it]

 epoch loss : 0.9448508209351337 | val loss : 0.9470200069899698
39


 39%|███▉      | 39/100 [06:37<10:20, 10.17s/it]

 epoch loss : 0.941611483435094 | val loss : 0.9486444465164999
40


 40%|████      | 40/100 [06:48<10:16, 10.27s/it]

 epoch loss : 0.9409629757399167 | val loss : 0.9505653667797163
41


 41%|████      | 41/100 [06:58<10:11, 10.37s/it]

 epoch loss : 0.9420668359571721 | val loss : 0.9486814167314362
42


 42%|████▏     | 42/100 [07:08<09:48, 10.15s/it]

 epoch loss : 0.9415251300125927 | val loss : 0.9469484166034217
43


 43%|████▎     | 43/100 [07:19<09:46, 10.28s/it]

 epoch loss : 0.9401532399964023 | val loss : 0.9471385299580769
44


 44%|████▍     | 44/100 [07:29<09:36, 10.30s/it]

 epoch loss : 0.9399626796250735 | val loss : 0.9448915448003602
45


 45%|████▌     | 45/100 [07:39<09:27, 10.31s/it]

 epoch loss : 0.9404084278972118 | val loss : 0.9449109530564651
46


 46%|████▌     | 46/100 [07:49<09:11, 10.21s/it]

 epoch loss : 0.9398539188555824 | val loss : 0.9468990665616341
47


 47%|████▋     | 47/100 [07:59<08:59, 10.18s/it]

 epoch loss : 0.9385161951055259 | val loss : 0.9450976999060622
48


 48%|████▊     | 48/100 [08:10<08:52, 10.23s/it]

 epoch loss : 0.9390409799107226 | val loss : 0.9461234420248605
49


 49%|████▉     | 49/100 [08:20<08:44, 10.29s/it]

 epoch loss : 0.9388238032916923 | val loss : 0.9464890372405932
50


 50%|█████     | 50/100 [08:30<08:31, 10.24s/it]

 epoch loss : 0.937748001960965 | val loss : 0.9466953228399592
51


 51%|█████     | 51/100 [08:40<08:15, 10.12s/it]

 epoch loss : 0.936910456328681 | val loss : 0.945105975692712
52


 52%|█████▏    | 52/100 [08:51<08:09, 10.20s/it]

 epoch loss : 0.9378810720719817 | val loss : 0.9451223811478291
53


 53%|█████▎    | 53/100 [09:01<08:02, 10.26s/it]

 epoch loss : 0.936399828787748 | val loss : 0.946478238672886
54


 54%|█████▍    | 54/100 [09:11<07:52, 10.27s/it]

 epoch loss : 0.9369523715211715 | val loss : 0.9490188004322422
55


 55%|█████▌    | 55/100 [09:21<07:34, 10.10s/it]

 epoch loss : 0.9360587145342971 | val loss : 0.9480029845121994
56


 56%|█████▌    | 56/100 [09:31<07:28, 10.19s/it]

 epoch loss : 0.9357574177639825 | val loss : 0.9494900683176171
57


 57%|█████▋    | 57/100 [09:42<07:21, 10.28s/it]

 epoch loss : 0.9373757075760272 | val loss : 0.9484758076158543
58


 58%|█████▊    | 58/100 [09:52<07:12, 10.31s/it]

 epoch loss : 0.9351325901426795 | val loss : 0.9493293574134123
59


 59%|█████▉    | 59/100 [10:02<06:53, 10.08s/it]

 epoch loss : 0.9354699009553694 | val loss : 0.9455777774158033
60


 60%|██████    | 60/100 [10:12<06:45, 10.15s/it]

 epoch loss : 0.9357641749890335 | val loss : 0.9445337806511851
61


 61%|██████    | 61/100 [10:22<06:38, 10.22s/it]

 epoch loss : 0.93457035349432 | val loss : 0.9494896290950405
62


 62%|██████▏   | 62/100 [10:33<06:29, 10.25s/it]

 epoch loss : 0.9360685054680485 | val loss : 0.9497129234295447
63


 63%|██████▎   | 63/100 [10:43<06:16, 10.19s/it]

 epoch loss : 0.934419543163859 | val loss : 0.9481094725502347
64


 64%|██████▍   | 64/100 [10:53<06:06, 10.19s/it]

 epoch loss : 0.9360054864253833 | val loss : 0.9440187389410816
65


 65%|██████▌   | 65/100 [11:03<05:59, 10.27s/it]

 epoch loss : 0.9371381477811636 | val loss : 0.948432315032459
66


 66%|██████▌   | 66/100 [11:14<05:49, 10.29s/it]

 epoch loss : 0.9356271674880734 | val loss : 0.94782409534871
67


 67%|██████▋   | 67/100 [11:24<05:38, 10.26s/it]

 epoch loss : 0.9349749275358208 | val loss : 0.9436551112573124
68


 68%|██████▊   | 68/100 [11:34<05:24, 10.13s/it]

 epoch loss : 0.934932989314263 | val loss : 0.949478421975108
69


 69%|██████▉   | 69/100 [11:44<05:16, 10.22s/it]

 epoch loss : 0.9351139770764293 | val loss : 0.9462913514919651
70


 70%|███████   | 70/100 [11:55<05:08, 10.29s/it]

 epoch loss : 0.9350860584607888 | val loss : 0.9432463622787624
71


 71%|███████   | 71/100 [12:05<04:58, 10.31s/it]

 epoch loss : 0.9341629786795868 | val loss : 0.9450512386641456
72


 72%|███████▏  | 72/100 [12:15<04:42, 10.10s/it]

 epoch loss : 0.9395907821851376 | val loss : 0.9535572618535422
73


 73%|███████▎  | 73/100 [12:25<04:35, 10.19s/it]

 epoch loss : 0.9366904049789234 | val loss : 0.9497213687711549
74


 74%|███████▍  | 74/100 [12:35<04:25, 10.22s/it]

 epoch loss : 0.9385312602762536 | val loss : 0.9511666679845273
75


 75%|███████▌  | 75/100 [12:46<04:16, 10.27s/it]

 epoch loss : 0.938442466449944 | val loss : 0.9546930671316906
76


 76%|███████▌  | 76/100 [12:55<04:01, 10.07s/it]

 epoch loss : 0.9370404765074387 | val loss : 0.9542250274454506
77


 77%|███████▋  | 77/100 [13:06<03:54, 10.18s/it]

 epoch loss : 0.937120487099086 | val loss : 0.9484545177626378
78


 78%|███████▊  | 78/100 [13:16<03:44, 10.22s/it]

 epoch loss : 0.9366713267190632 | val loss : 0.9500810289845883
79


 79%|███████▉  | 79/100 [13:27<03:35, 10.28s/it]

 epoch loss : 0.9368904307484627 | val loss : 0.9496269425720845
80


 80%|████████  | 80/100 [13:36<03:22, 10.15s/it]

 epoch loss : 0.9375087912658077 | val loss : 0.9496188184011329
81


 81%|████████  | 81/100 [13:47<03:12, 10.16s/it]

 epoch loss : 0.9375376762314276 | val loss : 0.9491273890999914
82


 82%|████████▏ | 82/100 [13:57<03:03, 10.21s/it]

 epoch loss : 0.9380007511351531 | val loss : 0.9481556890080276
83


 83%|████████▎ | 83/100 [14:07<02:54, 10.25s/it]

 epoch loss : 0.9363636347380552 | val loss : 0.9491598400676134
84


 84%|████████▍ | 84/100 [14:17<02:42, 10.16s/it]

 epoch loss : 0.9369994766074857 | val loss : 0.9467199161214735
85


 85%|████████▌ | 85/100 [14:27<02:31, 10.09s/it]

 epoch loss : 0.9356024971643051 | val loss : 0.9489248594612751
86


 86%|████████▌ | 86/100 [14:37<02:22, 10.16s/it]

 epoch loss : 0.9374392822191313 | val loss : 0.9463613415227353
87


 87%|████████▋ | 87/100 [14:48<02:12, 10.22s/it]

 epoch loss : 0.9375241272377245 | val loss : 0.9476333445715672
88


 88%|████████▊ | 88/100 [14:58<02:02, 10.22s/it]

 epoch loss : 0.9375004453253952 | val loss : 0.9461493963755451
89


 89%|████████▉ | 89/100 [15:08<01:50, 10.08s/it]

 epoch loss : 0.9375495931286832 | val loss : 0.9457729021901066
90


 90%|█████████ | 90/100 [15:18<01:41, 10.15s/it]

 epoch loss : 0.9386647431687876 | val loss : 0.9481280156709615
91


 91%|█████████ | 91/100 [15:28<01:31, 10.21s/it]

 epoch loss : 0.9366788609009801 | val loss : 0.9479901741430597
92


 92%|█████████▏| 92/100 [15:39<01:21, 10.24s/it]

 epoch loss : 0.9534262674517961 | val loss : 0.9565581774827346
93


 93%|█████████▎| 93/100 [15:48<01:10, 10.09s/it]

 epoch loss : 0.9500784974916157 | val loss : 0.9532144590488916
94


 94%|█████████▍| 94/100 [15:59<01:01, 10.17s/it]

 epoch loss : 0.9482360932153541 | val loss : 0.9528376396419933
95


 95%|█████████▌| 95/100 [16:09<00:51, 10.22s/it]

 epoch loss : 0.9444735265655435 | val loss : 0.9511156241291935
96


 96%|█████████▌| 96/100 [16:20<00:41, 10.27s/it]

 epoch loss : 0.944886843524712 | val loss : 0.9636534452438354
97


 97%|█████████▋| 97/100 [16:29<00:30, 10.06s/it]

 epoch loss : 0.944381918354984 | val loss : 0.9509373807791367
98


 98%|█████████▊| 98/100 [16:39<00:20, 10.15s/it]

 epoch loss : 0.9469316765601501 | val loss : 0.966179654436204
99


 99%|█████████▉| 99/100 [16:50<00:10, 10.21s/it]

 epoch loss : 0.9450009715182957 | val loss : 0.9580898814409682
100


100%|██████████| 100/100 [17:00<00:00, 10.21s/it]

 epoch loss : 0.945556880011187 | val loss : 0.9551149699294451
Training stopped after epoch 99





In [60]:
torch.save(model.state_dict(), 'model.pth')

# **Evaluation**

In [62]:
model = MyModel(vocab_size, output_dim, max_length)
model.load_state_dict(torch.load("/content/model.pth", map_location=torch.device('cpu')))

model.eval()
X_test = X_test.to("cpu")
# Compute the model's prediction for the padded sequence
with torch.no_grad():
    pred_ = model(X_test)

In [63]:
pred__ = pred_.cpu().numpy()
preds = np.argmax(pred__, axis=1)

In [64]:
def one_hot(a, num_classes):
  return np.squeeze(np.eye(num_classes)[a.reshape(-1)])
pred_hot = one_hot(preds,4)

In [65]:
x_np = torch.from_numpy(pred_hot)

In [66]:
from sklearn.metrics import accuracy_score,f1_score,precision_score,recall_score
recall_score(y_test, x_np,average='macro') 

0.7877673672498589

In [67]:
precision_score(y_test, x_np,average='macro') 

0.7878267057080199

In [68]:
accuracy_score(y_test, pred_hot)

0.7881321090859257

In [69]:
f1_score(y_test, x_np,average='macro')

0.7874140608853802

# **Testing**

In [76]:
# Set the model to evaluation mode
model.eval()
device = "cuda"
# Tokenize the new complaint and pad the sequence
new_complaint = ['واش دير للعشى']
#new_complaint  = ["شوكران علا هاد "]
#new_complaint  = ["نبي نروح للحوش"]
#new_complaint  = ["شبيك شتحب "]
seq = tokenizer.texts_to_sequences(new_complaint)
padded = pad_sequences(seq, maxlen=max_length)

# Convert the padded sequence to a PyTorch tensor and move it to the device (e.g., GPU) if available
padded_tensor = torch.LongTensor(padded).to("cpu")

# Compute the model's prediction for the padded sequence
with torch.no_grad():
    pred = model(padded_tensor)

# Move the prediction back to the CPU and convert to a numpy array
pred = pred.cpu().numpy()

# Map the prediction to a class label using the CLASS_DICT{'DZ': 0, 'LY': 1, 'MA': 2, 'TN': 3}
CLASS_DICT = {1: "LY", 3: "TN", 2: "MA", 0: "DZ" }
class_label = CLASS_DICT[np.argmax(pred)]

# Print the prediction and the predicted class label
print(class_label)

DZ
