# Linear-Log Model 

## 0.导入库

In [1]:
import math
import pickle
import re
from collections import Counter

import nltk
import numpy as np
import pandas as pd
from nltk import pos_tag
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
from scipy.sparse import csr_matrix
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

In [2]:
nltk.download("wordnet")
nltk.download("omw-1.4")
nltk.download("averaged_perceptron_tagger")

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\ZDF\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\ZDF\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\ZDF\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

## 1.数据导入与预处理

In [3]:
df = pd.read_csv(
    "ag_news_csv/train.csv", header=None, names=["label", "title", "description"]
)

In [3]:
df = pd.read_csv(
    "ag_news_csv/test.csv", header=None, names=["label", "title", "description"]
)

In [4]:
print(df.head())

   label                                              title  \
0      3  Wall St. Bears Claw Back Into the Black (Reuters)   
1      3  Carlyle Looks Toward Commercial Aerospace (Reu...   
2      3    Oil and Economy Cloud Stocks' Outlook (Reuters)   
3      3  Iraq Halts Oil Exports from Main Southern Pipe...   
4      3  Oil prices soar to all-time record, posing new...   

                                         description  
0  Reuters - Short-sellers, Wall Street's dwindli...  
1  Reuters - Private investment firm Carlyle Grou...  
2  Reuters - Soaring crude prices plus worries\ab...  
3  Reuters - Authorities have halted oil export\f...  
4  AFP - Tearaway world oil prices, toppling reco...  


In [5]:
def replace_space(word):
    return re.sub(r"[-\\/&]", " ", word)

In [6]:
df["title"] = df["title"].apply(replace_space)
df["description"] = df["description"].apply(replace_space)

In [7]:
def replace_num(word):
    return re.sub(r"\d+", "<NUM>", word)

In [8]:
df["title"] = df["title"].apply(replace_num)
df["description"] = df["description"].apply(replace_num)

In [9]:
def separate_num(word):
    return re.sub(r"(<NUM>)", r" \1 ", word)

In [10]:
df["title"] = df["title"].apply(separate_num)
df["description"] = df["description"].apply(separate_num)

In [11]:
print(df)

        label                                              title  \
0           3  Wall St. Bears Claw Back Into the Black (Reuters)   
1           3  Carlyle Looks Toward Commercial Aerospace (Reu...   
2           3    Oil and Economy Cloud Stocks' Outlook (Reuters)   
3           3  Iraq Halts Oil Exports from Main Southern Pipe...   
4           3  Oil prices soar to all time record, posing new...   
...       ...                                                ...   
119995      1  Pakistan's Musharraf Says Won't Quit as Army C...   
119996      2                  Renteria signing a top shelf deal   
119997      2                    Saban not going to Dolphins yet   
119998      2                                  Today's NFL games   
119999      2                       Nets get Carter from Raptors   

                                              description  
0       Reuters   Short sellers, Wall Street's dwindli...  
1       Reuters   Private investment firm Carlyle Grou...  
2  

In [12]:
def tokenize(text):
    return text.split()

In [13]:
df["tokens"] = df["title"].apply(tokenize) + df["description"].apply(tokenize)

In [14]:
print(df["tokens"])

0         [Wall, St., Bears, Claw, Back, Into, the, Blac...
1         [Carlyle, Looks, Toward, Commercial, Aerospace...
2         [Oil, and, Economy, Cloud, Stocks', Outlook, (...
3         [Iraq, Halts, Oil, Exports, from, Main, Southe...
4         [Oil, prices, soar, to, all, time, record,, po...
                                ...                        
119995    [Pakistan's, Musharraf, Says, Won't, Quit, as,...
119996    [Renteria, signing, a, top, shelf, deal, Red, ...
119997    [Saban, not, going, to, Dolphins, yet, The, Mi...
119998    [Today's, NFL, games, PITTSBURGH, at, NY, GIAN...
119999    [Nets, get, Carter, from, Raptors, INDIANAPOLI...
Name: tokens, Length: 120000, dtype: object


In [15]:
df.drop("description", axis=1, inplace=True)
df.drop("title", axis=1, inplace=True)

In [16]:
print(df)

        label                                             tokens
0           3  [Wall, St., Bears, Claw, Back, Into, the, Blac...
1           3  [Carlyle, Looks, Toward, Commercial, Aerospace...
2           3  [Oil, and, Economy, Cloud, Stocks', Outlook, (...
3           3  [Iraq, Halts, Oil, Exports, from, Main, Southe...
4           3  [Oil, prices, soar, to, all, time, record,, po...
...       ...                                                ...
119995      1  [Pakistan's, Musharraf, Says, Won't, Quit, as,...
119996      2  [Renteria, signing, a, top, shelf, deal, Red, ...
119997      2  [Saban, not, going, to, Dolphins, yet, The, Mi...
119998      2  [Today's, NFL, games, PITTSBURGH, at, NY, GIAN...
119999      2  [Nets, get, Carter, from, Raptors, INDIANAPOLI...

[120000 rows x 2 columns]


In [17]:
def lower(tokens):
    return [word.lower() for word in tokens]

In [18]:
df["tokens"] = df["tokens"].apply(lower)

In [19]:
print(df["tokens"])

0         [wall, st., bears, claw, back, into, the, blac...
1         [carlyle, looks, toward, commercial, aerospace...
2         [oil, and, economy, cloud, stocks', outlook, (...
3         [iraq, halts, oil, exports, from, main, southe...
4         [oil, prices, soar, to, all, time, record,, po...
                                ...                        
119995    [pakistan's, musharraf, says, won't, quit, as,...
119996    [renteria, signing, a, top, shelf, deal, red, ...
119997    [saban, not, going, to, dolphins, yet, the, mi...
119998    [today's, nfl, games, pittsburgh, at, ny, gian...
119999    [nets, get, carter, from, raptors, indianapoli...
Name: tokens, Length: 120000, dtype: object


In [20]:
def remove_word_suffixes(word):
    if word.endswith("'s"):
        word = word[:-2]
    # elif word.endswith("s"):
    #    word = word[:-1]
    else:
        return re.sub(r'[.,:()\'"?;#$!]', "", word)

In [21]:
def remove_suffixes(tokens):
    return [remove_word_suffixes(word) for word in tokens]

In [22]:
df["tokens"] = df["tokens"].apply(remove_suffixes)

In [23]:
print(df["tokens"])

0         [wall, st, bears, claw, back, into, the, black...
1         [carlyle, looks, toward, commercial, aerospace...
2         [oil, and, economy, cloud, stocks, outlook, re...
3         [iraq, halts, oil, exports, from, main, southe...
4         [oil, prices, soar, to, all, time, record, pos...
                                ...                        
119995    [None, musharraf, says, wont, quit, as, army, ...
119996    [renteria, signing, a, top, shelf, deal, red, ...
119997    [saban, not, going, to, dolphins, yet, the, mi...
119998    [None, nfl, games, pittsburgh, at, ny, giants,...
119999    [nets, get, carter, from, raptors, indianapoli...
Name: tokens, Length: 120000, dtype: object


In [24]:
def remove_stopwords(tokens):
    return [word for word in tokens if (word not in stopwords) and (word is not None)]

In [25]:
with open("stopwords.txt") as file:
    stopwords = file.read().split(",")

In [26]:
print(stopwords)

['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'A', 'about', 'above', 'across', 'after', 'again', 'against', 'all', 'almost', 'alone', 'along', 'already', 'also', 'although', 'always', 'am', 'among', 'an', 'and', 'another', 'any', 'anyone', 'anything', 'anywhere', 'are', "aren't", 'around', 'as', 'at', 'b', 'B', 'back', 'be', 'became', 'because', 'become', 'becomes', 'been', 'before', 'behind', 'being', 'below', 'between', 'both', 'but', 'by', 'c', 'C', 'can', 'cannot', "can't", 'could', "couldn't", 'd', 'D', 'did', "didn't", 'do', 'does', "doesn't", 'doing', 'done', "don't", 'down', 'during', 'e', 'E', 'each', 'either', 'enough', 'even', 'ever', 'every', 'everyone', 'everything', 'everywhere', 'f', 'F', 'few', 'find', 'first', 'for', 'four', 'from', 'full', 'further', 'g', 'G', 'get', 'give', 'go', 'h', 'H', 'had', "hadn't", 'has', "hasn't", 'have', "haven't", 'having', 'he', "he'd", "he'll", 'her', 'here', "here's", 'hers', 'herself', "he's", 'him', 'himself', 'his', 'how', 

In [27]:
df["tokens"] = df["tokens"].apply(remove_stopwords)

In [28]:
print(df["tokens"])

0         [wall, st, bears, claw, black, reuters, reuter...
1         [carlyle, looks, commercial, aerospace, reuter...
2         [oil, economy, cloud, stocks, outlook, reuters...
3         [iraq, halts, oil, exports, main, southern, pi...
4         [oil, prices, soar, time, record, posing, new,...
                                ...                        
119995    [musharraf, says, wont, quit, army, chief, kar...
119996    [renteria, signing, top, shelf, deal, red, sox...
119997    [saban, going, dolphins, miami, dolphins, cour...
119998    [nfl, games, pittsburgh, ny, giants, time, <nu...
119999    [nets, carter, raptors, indianapolis, star, vi...
Name: tokens, Length: 120000, dtype: object


In [29]:
def get_wordnet_pos(treebank_tag):
    if treebank_tag.startswith("J"):
        return wordnet.ADJ
    elif treebank_tag.startswith("V"):
        return wordnet.VERB
    elif treebank_tag.startswith("N"):
        return wordnet.NOUN
    elif treebank_tag.startswith("R"):
        return wordnet.ADV
    else:
        return wordnet.NOUN

In [30]:
def lemmatize_with_pos(tokens):
    pos_tagged = pos_tag(tokens)
    return [
        lemmatizer.lemmatize(token, get_wordnet_pos(pos)) for token, pos in pos_tagged
    ]

In [None]:
lemmatizer = WordNetLemmatizer()

df["tokens"] = df["tokens"].apply(lemmatize_with_pos)

In [None]:
print(df["tokens"])

In [33]:
df.to_pickle("processed_train_data.pkl")

In [33]:
df.to_pickle("processed_test_data.pkl")

## 2.TF-IDF编码

In [2]:
df = pd.read_pickle("processed_train_data.pkl")

In [52]:
df = pd.read_pickle("processed_test_data.pkl")

In [3]:
print(df)

        label                                             tokens
0           3  [wall, st, bear, claw, black, reuters, reuters...
1           3  [carlyle, look, commercial, aerospace, reuters...
2           3  [oil, economy, cloud, stock, outlook, reuters,...
3           3  [iraq, halt, oil, export, main, southern, pipe...
4           3  [oil, price, soar, time, record, pose, new, me...
...       ...                                                ...
119995      1  [musharraf, say, wont, quit, army, chief, kara...
119996      2  [renteria, sign, top, shelf, deal, red, sox, g...
119997      2  [saban, go, dolphin, miami, dolphin, courtship...
119998      2  [nfl, game, pittsburgh, ny, giant, time, <num>...
119999      2  [net, carter, raptor, indianapolis, star, vinc...

[120000 rows x 2 columns]


In [4]:
words_counter = Counter()
for tokens in df["tokens"]:
    words_counter.update(tokens)
vocabulary = dict(words_counter)

In [5]:
counter = 0
for key, value in vocabulary.items():
    if counter < 20:
        print(f"{key}: {value}")
        counter += 1
    else:
        break

wall: 1500
st: 1679
bear: 722
claw: 36
black: 836
reuters: 17270
short: 924
seller: 105
dwindle: 48
band: 240
ultra: 81
cynic: 6
see: 1861
green: 864
carlyle: 16
look: 2786
commercial: 541
aerospace: 129
private: 721
investment: 986


In [6]:
def compute_tf(tokens):
    tf = Counter(tokens)
    for i in tf:
        tf[i] = (1 + math.log10(tf[i])) if tf[i] != 0 else 0
    return dict(tf)

In [7]:
TF = [compute_tf(tokens) for tokens in df["tokens"]]

In [8]:
counter = 0
for i in TF:
    if counter < 10:
        print(i)
        counter += 1
    else:
        break

{'wall': 1.3010299956639813, 'st': 1.0, 'bear': 1.0, 'claw': 1.0, 'black': 1.0, 'reuters': 1.3010299956639813, 'short': 1.0, 'seller': 1.0, 'dwindle': 1.0, 'band': 1.0, 'ultra': 1.0, 'cynic': 1.0, 'see': 1.0, 'green': 1.0}
{'carlyle': 1.3010299956639813, 'look': 1.0, 'commercial': 1.0, 'aerospace': 1.0, 'reuters': 1.3010299956639813, 'private': 1.0, 'investment': 1.0, 'firm': 1.0, 'group': 1.0, 'reputation': 1.0, 'make': 1.0, 'time': 1.0, 'occasionally': 1.0, 'controversial': 1.0, 'play': 1.0, 'defense': 1.0, 'industry': 1.0, 'quietly': 1.0, 'place': 1.0, 'bet': 1.0, 'market': 1.0}
{'oil': 1.0, 'economy': 1.3010299956639813, 'cloud': 1.0, 'stock': 1.3010299956639813, 'outlook': 1.3010299956639813, 'reuters': 1.3010299956639813, 'soar': 1.0, 'crude': 1.0, 'price': 1.0, 'plus': 1.0, 'worry': 1.0, 'earnings': 1.0, 'expect': 1.0, 'hang': 1.0, 'market': 1.0, 'week': 1.0, 'depth': 1.0, 'summer': 1.0, 'doldrums': 1.0}
{'iraq': 1.3010299956639813, 'halt': 1.3010299956639813, 'oil': 1.477121254

In [9]:
def compute_idf(dft, df_tokens_len):
    return math.log10(df_tokens_len / dft)

In [10]:
IDF = {word: compute_idf(dft, len(df["tokens"])) for word, dft in vocabulary.items()}

In [12]:
with open("IDF.pkl", "wb") as f:
    pickle.dump(IDF, f)

In [13]:
with open("IDF.pkl", "rb") as f:
    IDF = pickle.load(f)

In [58]:
counter = 0
for key, value in IDF.items():
    if counter < 20:
        print(f"{key}: {value}")
        counter += 1
    else:
        break

wall: 0.7047223332251101
st: 0.6557628961427425
bear: 1.0222763947111522
claw: 2.324511091513504
black: 0.958607314841775
reuters: -0.35647874528666745
short: 0.9151416210606848
seller: 1.8596242932108533
dwindle: 2.199572354905204
band: 1.5006023505691855
ultra: 1.9723285734021416
cynic: 3.102662341897148
see: 0.6110672191500244
green: 0.9442998498018981
carlyle: 2.6766936096248664
look: 0.4358324801928467
commercial: 1.147616327174222
aerospace: 1.7702238819815423
private: 1.0228783275613624
investment: 0.8869366773395801


In [59]:
data = []
indices = []
indptr = [0]

In [60]:
word_list = list(IDF.keys())
word_to_index = {word: i for i, word in enumerate(word_list)}

In [61]:
for i in range(len(TF)):
    for word, tf in TF[i].items():
        if word in IDF:
            tf_idf = tf * IDF[word]
            data.append(tf_idf)
            indices.append(word_to_index[word])
    indptr.append(len(data))

In [62]:
tf_idf = csr_matrix((data, indices, indptr), shape=(len(TF), len(IDF)), dtype=float)

In [76]:
print(tf_idf)

  (0, 2206)	0.773264462536105
  (0, 6322)	1.114400745168392
  (0, 239)	0.40763442637320413
  (0, 4888)	0.5077176052020644
  (0, 3694)	1.4109915763026284
  (0, 3102)	0.6387693529112404
  (0, 17148)	2.0814730428272097
  (0, 21062)	3.4036923375611288
  (0, 66)	-0.5817491067734005
  (0, 6101)	2.132625565274591
  (0, 7934)	2.1648102486459924
  (0, 1504)	1.222802195623679
  (0, 20)	0.484091313777018
  (0, 1095)	0.4618497615771687
  (0, 2752)	2.227601078505448
  (1, 1493)	0.6043517881075472
  (1, 324)	0.24707735517949558
  (1, 18)	1.0228783275613624
  (1, 918)	0.22131044093316934
  (1, 1319)	0.2772362241293246
  (1, 961)	0.46101357765046674
  (1, 3528)	1.5524990983818976
  (1, 1201)	0.8378445188876112
  (1, 1514)	2.3493346752385365
  (1, 1214)	1.831924578604543
  :	:
  (7598, 2355)	1.211496711714679
  (7598, 2056)	0.28487668601161775
  (7598, 583)	0.042972730625268726
  (7598, 318)	0.5600443639421049
  (7598, 9286)	1.9265710828414664
  (7598, 1009)	0.8402112521667182
  (7598, 1532)	0.45805965

In [64]:
print(tf_idf.shape)

(7600, 58148)


## 3.构建并训练Log-Linear模型

In [70]:
class LogLinearModel:
    def __init__(self, n_features, n_classes):
        self.n_features = n_features
        self.n_classes = n_classes
        self.weights = np.zeros((n_classes, n_features))

    def train(self, X, y, lr=0.01, epochs=20, batch_size=64):
        n_samples = X.shape[0]
        for epoch in tqdm(range(epochs)):
            batch_losses = []
            shuffled_indices = np.random.permutation(n_samples)
            for start_index in tqdm(range(0, n_samples, batch_size), leave=False):
                end_index = min(start_index + batch_size, n_samples)
                batch_indices = shuffled_indices[start_index:end_index]

                batch_X = X[batch_indices].toarray()
                batch_y = y[batch_indices]

                scores = batch_X.dot(self.weights.T)
                probs = self._softmax(scores)

                loss = self._cross_entropy(probs, batch_y)
                batch_losses.append(loss)

                delta = (probs - batch_y).T.dot(batch_X)
                self.weights -= lr * delta

            epoch_loss = np.mean(batch_losses)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}")

    def predict(self, X):
        scores = X.dot(self.weights.T)
        probs = self._softmax(scores)
        return np.argmax(probs, axis=1)

    def _softmax(self, x):
        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return exp_x / exp_x.sum(axis=1, keepdims=True)

    def _cross_entropy(self, probs, y_true):
        log_probs = -np.log(probs[range(len(probs)), np.argmax(y_true, axis=1)])
        return np.mean(log_probs)

    def save(self, filepath):
        with open(filepath, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(filepath):
        with open(filepath, "rb") as f:
            return pickle.load(f)

In [73]:
y = np.eye(4)[df["label"] - 1]

In [74]:
print(y)

[[0. 0. 1. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 ...
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 1. 0.]]


In [100]:
X_train, X_val, y_train, y_val = train_test_split(
    tf_idf, y, test_size=0.2, random_state=42
)

In [101]:
model = LogLinearModel(X_train.shape[1], 4)

In [105]:
model.train(X_train, y_train, epochs=100)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 1/100, Loss: 0.0375


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 2/100, Loss: 0.0372


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 3/100, Loss: 0.0367


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 4/100, Loss: 0.0365


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 5/100, Loss: 0.0362


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 6/100, Loss: 0.0357


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 7/100, Loss: 0.0354


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 8/100, Loss: 0.0351


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 9/100, Loss: 0.0348


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 10/100, Loss: 0.0345


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 11/100, Loss: 0.0342


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 12/100, Loss: 0.0339


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 13/100, Loss: 0.0336


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 14/100, Loss: 0.0334


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 15/100, Loss: 0.0330


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 16/100, Loss: 0.0327


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 17/100, Loss: 0.0325


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 18/100, Loss: 0.0321


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 19/100, Loss: 0.0320


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 20/100, Loss: 0.0317


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 21/100, Loss: 0.0315


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 22/100, Loss: 0.0313


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 23/100, Loss: 0.0309


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 24/100, Loss: 0.0308


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 25/100, Loss: 0.0306


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 26/100, Loss: 0.0303


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 27/100, Loss: 0.0301


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 28/100, Loss: 0.0299


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 29/100, Loss: 0.0297


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 30/100, Loss: 0.0295


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 31/100, Loss: 0.0294


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 32/100, Loss: 0.0290


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 33/100, Loss: 0.0288


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 34/100, Loss: 0.0287


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 35/100, Loss: 0.0285


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 36/100, Loss: 0.0284


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 37/100, Loss: 0.0281


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 38/100, Loss: 0.0279


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 39/100, Loss: 0.0278


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 40/100, Loss: 0.0276


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 41/100, Loss: 0.0274


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 42/100, Loss: 0.0273


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 43/100, Loss: 0.0271


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 44/100, Loss: 0.0270


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 45/100, Loss: 0.0268


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 46/100, Loss: 0.0266


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 47/100, Loss: 0.0264


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 48/100, Loss: 0.0264


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 49/100, Loss: 0.0262


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 50/100, Loss: 0.0260


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 51/100, Loss: 0.0258


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 52/100, Loss: 0.0258


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 53/100, Loss: 0.0256


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 54/100, Loss: 0.0255


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 55/100, Loss: 0.0253


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 56/100, Loss: 0.0252


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 57/100, Loss: 0.0251


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 58/100, Loss: 0.0249


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 59/100, Loss: 0.0248


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 60/100, Loss: 0.0246


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 61/100, Loss: 0.0245


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 62/100, Loss: 0.0244


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 63/100, Loss: 0.0242


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 64/100, Loss: 0.0242


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 65/100, Loss: 0.0240


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 66/100, Loss: 0.0239


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 67/100, Loss: 0.0238


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 68/100, Loss: 0.0238


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 69/100, Loss: 0.0237


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 70/100, Loss: 0.0236


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 71/100, Loss: 0.0234


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 72/100, Loss: 0.0232


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 73/100, Loss: 0.0232


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 74/100, Loss: 0.0231


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 75/100, Loss: 0.0230


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 76/100, Loss: 0.0228


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 77/100, Loss: 0.0227


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 78/100, Loss: 0.0226


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 79/100, Loss: 0.0225


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 80/100, Loss: 0.0224


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 81/100, Loss: 0.0224


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 82/100, Loss: 0.0222


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 83/100, Loss: 0.0222


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 84/100, Loss: 0.0220


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 85/100, Loss: 0.0221


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 86/100, Loss: 0.0219


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 87/100, Loss: 0.0217


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 88/100, Loss: 0.0217


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 89/100, Loss: 0.0216


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 90/100, Loss: 0.0215


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 91/100, Loss: 0.0214


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 92/100, Loss: 0.0213


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 93/100, Loss: 0.0213


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 94/100, Loss: 0.0211


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 95/100, Loss: 0.0211


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 96/100, Loss: 0.0209


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 97/100, Loss: 0.0209


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 98/100, Loss: 0.0208


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 99/100, Loss: 0.0208


  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 100/100, Loss: 0.0207


In [106]:
model.save("model.pkl")

## 4.评估与测试

In [None]:
def f1_score(y_true, y_pred):
    precisions = []
    recalls = []
    for label in np.unique(y_true):
        tp = np.sum((y_true == label) & (y_pred == label))
        fp = np.sum((y_true != label) & (y_pred == label))
        fn = np.sum((y_true == label) & (y_pred != label))

        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0

        precisions.append(precision)
        recalls.append(recall)

    f1_scores = []
    for precision, recall in zip(precisions, recalls):
        f1 = (
            2 * (precision * recall) / (precision + recall)
            if precision + recall > 0
            else 0
        )
        f1_scores.append(f1)

    weights = np.bincount(y_true) / len(y_true)
    weighted_f1 = np.sum(f1_scores * weights)

    return weighted_f1

In [71]:
loaded_model = LogLinearModel.load("model.pkl")

In [None]:
val_predictions = loaded_model.predict(X_val)
val_accuracy = np.mean(val_predictions == np.argmax(y_val, axis=1))
print(f"Validation Accuracy: {val_accuracy:.4f}")

In [112]:
f1 = f1_score(np.argmax(y_val, axis=1), val_predictions, average="weighted")
print(f"F1 Score: {f1:.4f}")

F1 Score: 0.8733


In [113]:
print("Classification Report:")
print(classification_report(np.argmax(y_val, axis=1), val_predictions))

Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.87      0.87      5956
           1       0.95      0.95      0.95      6058
           2       0.82      0.84      0.83      5911
           3       0.86      0.84      0.85      6075

    accuracy                           0.87     24000
   macro avg       0.87      0.87      0.87     24000
weighted avg       0.87      0.87      0.87     24000



In [77]:
val_predictions = loaded_model.predict(tf_idf)
val_accuracy = np.mean(val_predictions == np.argmax(y, axis=1))
print(f"Validation Accuracy: {val_accuracy:.4f}")

Validation Accuracy: 0.8043


In [80]:
f1 = f1_score(np.argmax(y, axis=1), val_predictions)
print(f"F1 Score: {f1:.4f}")

F1 Score: 0.8032


In [81]:
print("Classification Report:")
print(classification_report(np.argmax(y, axis=1), val_predictions))

Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.72      0.79      1900
           1       0.83      0.94      0.88      1900
           2       0.72      0.81      0.76      1900
           3       0.83      0.75      0.79      1900

    accuracy                           0.80      7600
   macro avg       0.81      0.80      0.80      7600
weighted avg       0.81      0.80      0.80      7600

