In [1]:
import torch
from torch import nn
from torch.nn import functional as F

from gsm import GSM

In [2]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from torch.utils.data import Dataset, DataLoader

# 1. Load the corpus and turn it into a bag-of-words matrix
newsgroups = fetch_20newsgroups(subset='all')
texts, labels = newsgroups.data, newsgroups.target

# Count-vectorize the raw texts: English stop words are dropped and the
# vocabulary is capped at 1000 terms. The sparse result is densified so it
# can be handed to torch later on.
vectorizer = CountVectorizer(stop_words='english', max_features=1000)
doc_term_counts = vectorizer.fit_transform(texts)
X = doc_term_counts.toarray()
y = labels  # keep the class ids alongside the features

# 2. Custom Dataset class
class NewsgroupsDataset(Dataset):
    """Map-style dataset pairing each feature row with its class label.

    Args:
        data: Array-like of per-sample feature rows; stored as a
            ``torch.float32`` tensor.
        labels: Array-like of integer class ids, one per row of ``data``;
            stored as a ``torch.long`` tensor.

    Raises:
        ValueError: If ``data`` and ``labels`` do not contain the same
            number of samples.
    """

    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        # Fail fast on misaligned inputs: __len__ only looks at `data`, so a
        # mismatch would otherwise surface as an IndexError in the middle of
        # an epoch (too few labels) or go unnoticed (too many labels).
        if self.data.shape[:1] != self.labels.shape[:1]:
            raise ValueError(
                "data and labels must contain the same number of samples, "
                f"got shapes {tuple(self.data.shape)} and {tuple(self.labels.shape)}"
            )

    def __len__(self):
        """Return the number of samples (rows of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]

# 3. Mini-batch size used by the DataLoader
batch_size = 64

# Wrap the dense count matrix and its labels in the custom dataset
dataset = NewsgroupsDataset(data=X, labels=y)

# Shuffled mini-batch iterator over the dataset
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)

In [3]:
# Build the GSM model. The first argument is taken from the width of the
# document-term matrix instead of repeating the vectorizer's max_features
# literal: CountVectorizer may yield fewer columns than the cap, and the two
# numbers must stay in sync.
# NOTE(review): argument order is presumably (input/vocabulary size, hidden
# size, number of topics) — confirm against gsm.GSM.
model = GSM(X.shape[1], 128, 20)

In [4]:
# Train the model on the shuffled mini-batch loader; `epoch=100` is passed
# through to GSM.learn, which prints one "epoch:N, loss:..." line per epoch.
# NOTE(review): in the recorded run the first-epoch loss is ~6.7e11 and the
# loss then hovers around 7.6e3-8.2e3 from epoch 4 onward without a clear
# downward trend — worth checking input scaling / learning rate in gsm.GSM.
model.learn(dataloader, epoch=100)

epoch:1, loss:669196606994.7722
epoch:2, loss:24435.58054971695
epoch:3, loss:8322.238621778786
epoch:4, loss:7799.998302966356
epoch:5, loss:7713.7074957937
epoch:6, loss:7961.374050796032
epoch:7, loss:7732.353099949658
epoch:8, loss:7687.321512624621
epoch:9, loss:7700.07757281512
epoch:10, loss:7703.989817798138
epoch:11, loss:7789.3257670253515
epoch:12, loss:7640.9875332042575
epoch:13, loss:7640.499278411269
epoch:14, loss:7870.102637529373
epoch:15, loss:7775.463030450046
epoch:16, loss:7685.3477323800325
epoch:17, loss:7799.037198789418
epoch:18, loss:8005.076414488256
epoch:19, loss:7592.69044046849
epoch:20, loss:8174.856428913772
epoch:21, loss:7674.508117079735
epoch:22, loss:7745.1491726487875
epoch:23, loss:7621.326284818351
epoch:24, loss:7684.0814577788115
epoch:25, loss:7593.285156816244
epoch:26, loss:8105.596265561879
epoch:27, loss:7625.577030979097
epoch:28, loss:7637.367488741875
epoch:29, loss:7777.698316968977
epoch:30, loss:7633.863264761865
epoch:31, loss:762

In [5]:
# Grab a single mini-batch and compute its topic proportions.
# next(iter(...)) replaces the enumerate-loop-with-immediate-break idiom.
batch_data, batch_lbl = next(iter(dataloader))

# Inspection only: disable autograd so no graph is built for this forward pass.
with torch.no_grad():
    # NOTE(review): assumes the model's parameters live on the GPU — confirm
    # where gsm.GSM places itself before running on a CPU-only machine.
    x = batch_data.to(device="cuda")
    z, _ = model.encode(x)
    # Normalise the first encoder output across dim 1 so each row sums to 1.
    topic = F.softmax(z, dim=1)

In [7]:
# Index of the highest-weight topic for every row of the inspected batch
topic.argmax(dim=1)

tensor([ 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5, 14,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5, 14,  5,  5,  5,  5,  5,  5,  5,
        14,  5,  7, 14,  5,  5,  5,  5,  5,  5], device='cuda:0')