In [1]:
from preprocessors.sequencers import CustomSequencer, BertSequencer
from trainers.custom_trainer import CustomTrainer
from trainers.bert_trainer import BertTrainer
from trainers.bert_trainer import BertEmailClassifier  # Important to import

In [2]:
# Sample inputs for both classifiers: two full Usenet-style posts (headers and
# quoted replies included) followed by two short one-line sentences.
# The order must match `labels`, which is zipped with the predictions later.
sentences = [
    # Post 1: reply about buying a ZyXEL modem with a Mac bundle.
    """
From: dr17@crux2.cit.cornell.edu (Dean M Robinson)
Subject: Re: Buying a high speed v.everything modem
Nntp-Posting-Host: crux2.cit.cornell.edu
Organization: Cornell University
Lines: 20

ejbehr@rs6000.cmp.ilstu.edu (Eric Behr) writes:

>Just a quick summary of recent findings re. high speed modems. Top three
>contenders seem to be AT&T Paradyne, ZyXEL, and US Robotics. ZyXEL has the
>biggest "cult following", and can be had for under $300, but I ignored it
>because I need something with Mac software, which will work without any
>tweaking.

You shouldn't have ignored the ZyXEL.  It can be purchased with a "Mac
bundle", which includes a hardware-handshaking cable and FaxSTF software.
The bundle adds between $35 and $60 to the price of the modem, depending
on the supplier.  It is true that the modem has no Mac-specific docs,
but it doesn't require much 'tweaking' (aside from setting &D0 in the
init string, to enable hardware handshaking).

For more information on the ZyXEL, including sources, look at various files
on sumex-aim.stanford.edu, in info-mac/report.

Disclaimer:  I have no affiliation with ZyXEL, though I did buy a ZyXEL
a U1496E modem.

    """,
    # Post 2: reply about key-escrow / key registering bodies.
    """
Subject: Re: Key Registering Bodies
From: a_rubin@dsg4.dse.beckman.com (Arthur Rubin)
Organization: Beckman Instruments, Inc.
Nntp-Posting-Host: dsg4.dse.beckman.com
Lines: 16

In <nagleC5w79E.7HL@netcom.com> nagle@netcom.com (John Nagle) writes:

>       Since the law requires that wiretaps be requested by the Executive
>Branch and approved by the Judicial Branch, it seems clear that one
>of the key registering bodies should be under the control of the
>Judicial Branch.  I suggest the Supreme Court, or, regionally, the
>Courts of Appeal.  More specifically, the offices of their Clerks.

Now THAT makes sense.  But the other half must be in a non-government
escrow.  (I still like EFF, but I admin their security has not been
tested.)

--
Arthur L. Rubin: a_rubin@dsg4.dse.beckman.com (work) Beckman Instruments/Brea
216-5888@mcimail.com 70707.453@compuserve.com arthur@pnet01.cts.com (personal)
My opinions are my own, and do not represent those of my employer.

    """
] + [
    # Short single-sentence inputs with no headers or quoted text.
    'my macbook heats up too much, but there is no better laptop for 3000 dollars',
    'I am a clear atheist, sometimes, I hear people say they are agnostic, I hate it.'
]

# Ground-truth newsgroup for each entry of `sentences`, in the same order:
# the two full posts first, then the two short sentences.
labels = [
    "comp.sys.mac.hardware",  # modem / Mac bundle post
    "sci.crypt",              # key registering bodies post
    "comp.sys.mac.hardware",  # macbook sentence
    "alt.atheism",            # atheist / agnostic sentence
]

In [3]:
# Class-index -> newsgroup-name lookup. A model's argmax index is used as the
# position in this list, so the order of the entries is significant.
id_to_label = [
    "alt.atheism",
    "comp.graphics", "comp.os.ms-windows.misc", "comp.sys.ibm.pc.hardware",
    "comp.sys.mac.hardware", "comp.windows.x",
    "misc.forsale",
    "rec.autos", "rec.motorcycles", "rec.sport.baseball", "rec.sport.hockey",
    "sci.crypt", "sci.electronics", "sci.med", "sci.space",
    "soc.religion.christian",
    "talk.politics.guns", "talk.politics.mideast", "talk.politics.misc",
    "talk.religion.misc",
]

In [4]:
# Build the custom sequencer, then replace its `tokenizer` attribute with the
# one stored in the JSON file (load_tokenizer is called on the class itself).
sequencer_custom = CustomSequencer()
sequencer_custom.tokenizer = CustomSequencer.load_tokenizer('../preprocessors/custom_tokenizer.json')

In [5]:
# Turn the raw sentences into model inputs with the custom sequencer.
sequences_custom = sequencer_custom.make_sequences(sentences)
# Trailing expression: display the shape as a quick sanity check
# (presumably (n_sentences, sequence_length) — confirm in CustomSequencer).
sequences_custom.shape

(4, 150)

In [6]:
# Create the custom-model trainer from the saved model file
# (presumably the constructor loads the model at this path — confirm in CustomTrainer).
custom_trainer = CustomTrainer('../trainers/models/linear.h5')

In [7]:
# Score every sequence with the custom model. Later cells take
# .argmax(axis=1) of the result, i.e. one row of class scores per sentence.
prediction_custom = custom_trainer.predict(sequences_custom)

In [8]:
# Most likely class per sentence: index of the largest score in each row.
ids = prediction_custom.argmax(axis=1)

# zip() silently stops at the shorter input, so a mismatch between the number
# of predictions and labels would go unnoticed; check the alignment up front.
assert len(ids) == len(labels), (
    f'got {len(ids)} predictions for {len(labels)} labels'
)

# Side-by-side report: predicted newsgroup vs. ground-truth label.
print("Predictions\t\tReal\n----------------------------")
for pred_id, real_label in zip(ids, labels):
    # id_to_label maps the class index back to its newsgroup name.
    print(f'{id_to_label[pred_id]}\t\t{real_label}')

Predictions		Real
----------------------------
comp.sys.mac.hardware		comp.sys.mac.hardware
sci.crypt		sci.crypt
rec.sport.baseball		comp.sys.mac.hardware
soc.religion.christian		alt.atheism


In [10]:
# NOTE(review): the execution counter jumps from [8] to [10] here, so a cell
# was run and removed (or run out of order). Re-run from a fresh kernel to
# confirm nothing below depends on hidden state.

# Build the BERT sequencer and encode the same sentences for the BERT model.
sequencer_bert = BertSequencer()
sequences_bert = sequencer_bert.make_sequences(sentences)
# Trailing expression: display the shape as a quick sanity check
# (presumably (n_sentences, sequence_length) — confirm in BertSequencer).
sequences_bert.shape

(4, 512)

In [11]:
# Create the BERT trainer from the saved checkpoint file.
# NOTE(review): the import cell marks BertEmailClassifier as "Important to
# import" — presumably the checkpoint is a pickled model object that needs the
# class importable when it is loaded; confirm in BertTrainer.
bert_trainer = BertTrainer(load_path='../trainers/models/bert_clf.pt')

In [12]:
# Score every sequence with the BERT model. Later cells take
# .argmax(axis=1) of the result, i.e. one row of class scores per sentence.
prediction_bert = bert_trainer.predict(sequences_bert)

In [13]:
# Most likely class per sentence: index of the largest score in each row.
# Note that this rebinds `ids`, replacing the custom model's indices.
ids = prediction_bert.argmax(axis=1)

# zip() silently stops at the shorter input, so a mismatch between the number
# of predictions and labels would go unnoticed; check the alignment up front.
assert len(ids) == len(labels), (
    f'got {len(ids)} predictions for {len(labels)} labels'
)

# Side-by-side report: predicted newsgroup vs. ground-truth label.
print("Predictions\t\tReal\n----------------------------")
for pred_id, real_label in zip(ids, labels):
    # id_to_label maps the class index back to its newsgroup name.
    print(f'{id_to_label[pred_id]}\t\t{real_label}')

Predictions		Real
----------------------------
comp.sys.mac.hardware		comp.sys.mac.hardware
sci.crypt		sci.crypt
misc.forsale		comp.sys.mac.hardware
alt.atheism		alt.atheism


In [14]:
# Inspect the raw BERT scores: print the full score matrix, then display the
# per-row argmax (the same indices the report cell stored in `ids`).
print(prediction_bert)
prediction_bert.argmax(axis=1)

tensor([[2.1233e-03, 1.1182e-03, 6.4410e-04, 2.6203e-03, 9.7992e-01, 9.8657e-04,
         1.7045e-03, 1.1660e-03, 2.2988e-04, 4.8275e-04, 3.7244e-04, 4.8188e-04,
         1.9904e-03, 1.0086e-03, 2.9782e-04, 1.2105e-03, 4.8972e-04, 9.8602e-04,
         1.5422e-03, 6.2061e-04],
        [1.3790e-03, 3.8299e-04, 6.0808e-04, 7.0174e-04, 4.1560e-04, 7.1869e-04,
         1.0412e-03, 7.3274e-04, 7.1932e-04, 9.7821e-04, 1.1487e-03, 9.7853e-01,
         1.2747e-03, 4.6278e-04, 9.7632e-04, 6.2176e-04, 2.8434e-03, 2.1772e-03,
         3.3126e-03, 9.7755e-04],
        [3.5820e-03, 5.7003e-03, 1.4491e-02, 6.5536e-03, 2.9461e-01, 4.3612e-03,
         5.1769e-01, 9.1748e-02, 8.0672e-03, 5.7857e-03, 5.8959e-03, 1.8336e-03,
         3.6739e-03, 9.0327e-03, 4.6993e-03, 6.8493e-03, 2.2674e-03, 3.5573e-03,
         5.5848e-03, 4.0166e-03],
        [7.4543e-01, 8.8457e-03, 9.3976e-03, 2.7785e-03, 2.1870e-02, 4.1538e-03,
         5.2523e-03, 2.3102e-03, 2.3648e-03, 1.0256e-02, 6.8491e-03, 3.5207e-03,
       

tensor([ 4, 11,  6,  0])

In [15]:
# Inspect the raw custom-model scores: print the full score matrix, then
# display the per-row argmax. It is recomputed here because `ids` was
# overwritten by the BERT report cell.
print(prediction_custom)
prediction_custom.argmax(axis=1)

[[2.31406568e-19 1.81764483e-15 1.01153441e-08 2.32112171e-07
  9.99999762e-01 1.87728288e-11 9.12295217e-09 4.90418417e-09
  1.78645473e-10 7.20960767e-14 5.00692966e-14 6.47765058e-16
  2.16097344e-08 4.15439760e-15 2.40586034e-13 8.90908013e-18
  2.83749734e-23 8.06993346e-24 2.22078597e-18 4.94397167e-19]
 [3.33759829e-08 1.80159815e-12 6.82441583e-20 1.52260804e-09
  2.45338477e-15 5.17724427e-13 8.81763082e-24 1.84482759e-13
  1.14253829e-09 2.31020724e-23 1.76744965e-16 1.00000000e+00
  1.41780018e-10 7.05825237e-14 5.05246422e-09 9.61604962e-11
  4.46470257e-11 1.07726155e-20 5.90609228e-10 2.10220708e-09]
 [1.32372748e-04 5.92641719e-02 1.89327389e-01 3.00018284e-02
  3.49076353e-02 4.11276380e-03 6.59645051e-02 1.33163810e-01
  3.70249734e-03 2.66374737e-01 1.18337460e-02 2.49315344e-04
  7.81236589e-02 9.97302756e-02 5.63397072e-03 1.12903854e-02
  3.33014294e-03 5.01253700e-04 1.34756102e-03 1.00791466e-03]
 [6.31099218e-04 1.50614930e-02 6.62229955e-02 1.05250254e-02
  4.0

array([ 4, 11,  9, 15])