# Chapter 8: Neural Networks

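Taking the positive/negative classification task from Chapter 7 as the subject, implement a classification model with a neural network. In this chapter, use a deep learning framework such as PyTorch, TensorFlow, or JAX.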

## 70. Loading word embeddings

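Using pretrained word embeddings, build a $|V| \times d_{\mathrm{emb}}$ word embedding matrix $\pmb{E}$, where $|V|$ is the vocabulary size of the word embeddings and $d_{\mathrm{emb}}$ is their dimensionality. The first row vector $\pmb{E}_{0,:}$ of the embedding matrix will later serve as the embedding of the padding (`<PAD>`) token, so reserve it as a zero vector. Consequently, the pretrained embeddings are loaded into the second and subsequent rows of $\pmb{E}$.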

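If you load all of the [pretrained word vectors](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing) from the Google News dataset (3 million words and phrases, 300 dimensions), you should get $|V| = 3000001$ and $d_{\mathrm{emb}} = 300$ (the 3 million entries plus the padding row). However, those 3 million words include many rare words that are hardly ever used, so reducing the vocabulary saves memory.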

While building the embedding matrix, also maintain a bidirectional mapping between the index of each row of the matrix (the token ID) and the corresponding word (token).
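A minimal NumPy sketch of the layout described above, assuming the pretrained vectors arrive as a list of words plus a matching `(len(words), d_emb)` array (gensim's `KeyedVectors` exposes these as `index_to_key` and `vectors`). The helper name `build_embedding_matrix` and the toy words are hypothetical. Row 0 stays zero for PAD, and both dictionaries use the same shifted IDs.

```python
import numpy as np

def build_embedding_matrix(words, vectors, pad_token="<pad>"):
    """Build E with row 0 reserved (all zeros) for padding.

    words:   vocabulary words; row i of `vectors` belongs to words[i]
    vectors: array of shape (len(words), d_emb)
    """
    vocab_size = len(words) + 1  # +1 for the PAD row
    d_emb = vectors.shape[1]
    E = np.zeros((vocab_size, d_emb), dtype=vectors.dtype)
    E[1:] = vectors  # pretrained vectors start at the second row

    # Bidirectional mapping between token IDs and words, sharing the same IDs
    word_to_id = {pad_token: 0}
    id_to_word = {0: pad_token}
    for i, word in enumerate(words, start=1):
        word_to_id[word] = i
        id_to_word[i] = word
    return E, word_to_id, id_to_word

# Toy example (hypothetical words and vectors)
toy_words = ["king", "queen", "movie"]
toy_vectors = np.arange(6, dtype=np.float32).reshape(3, 2)
E, w2i, i2w = build_embedding_matrix(toy_words, toy_vectors)
print(E.shape)       # (4, 2)
print(w2i["queen"])  # 2
print(i2w[2])        # queen
```

Filling `E[1:]` with one slice assignment also avoids a Python loop over 3 million entries.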

In [1]:
!uv pip install -r requirements.txt
# It seems the kernel has to be restarted whenever the numpy version changes

Using Python 3.11.12 environment at: /home/tosshy/workspace/2025/.venv
Audited 141 packages in 34ms


In [1]:
import gensim.downloader as api
import numpy as np
import jax
import jax.numpy as jnp

word_vectors = api.load('word2vec-google-news-300')

In [2]:
jax.config.update('jax_platform_name', 'cpu')

In [3]:
import jax
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())
print("Default device:", jax.default_backend())

JAX version: 0.6.0
Available devices: [CpuDevice(id=0)]
Default device: cpu


In [41]:
embedding_dim = word_vectors.vector_size

PAD_TOKEN = '<pad>'
PAD_ID = 0

word_to_id = {}
id_to_word = {}

word_to_id[PAD_TOKEN] = PAD_ID
id_to_word[PAD_ID] = PAD_TOKEN

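# Token IDs start at 1 because ID 0 is reserved for PAD
for i, word in enumerate(word_vectors.index_to_key):
    current_id = i + 1
    word_to_id[word] = current_id
    id_to_word[current_id] = word  # use the shifted ID so both maps agree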

vocab_size = len(word_to_id)

In [42]:
import numpy as np
from tqdm import tqdm

# float16 halves memory: 3000001 x 300 x 2 bytes = 1,800,000,600 bytes
np_embedding_matrix = np.zeros((vocab_size, embedding_dim), dtype=np.float16)

print(f"vocab size: {vocab_size}, embedding dim: {embedding_dim}")
for word, word_id in tqdm(word_to_id.items(), desc="Creating embedding matrix"):
    if word == PAD_TOKEN:
        continue
    np_embedding_matrix[word_id] = word_vectors[word]

embedding_matrix = jax.device_put(np_embedding_matrix)

# Check which device the array was placed on (CPU here, because of the jax_platform_name setting above)
print(f"Device: {embedding_matrix.device}")


vocab size: 3000001, embedding dim: 300


Creating embedding matrix: 100%|█████████████████████████████████████████████████████████████| 3000001/3000001 [00:08<00:00, 374220.00it/s]

Device: TFRT_CPU_0





In [None]:
embedding_matrix = jax.device_put(embedding_matrix, jax.devices('cpu')[0])
print(f"Device: {embedding_matrix.device}")

Device: TFRT_CPU_0


In [43]:
print(embedding_matrix.nbytes)

1800000600


In [8]:
example = 'king'
if example in word_to_id:
    example_id = word_to_id[example]
    print(f'word: {example} ID: {example_id}')
    print(f'word: {example} JAX vector: {embedding_matrix[example_id]}')

word: king ID: 6148
word: king JAX vector: [ 1.2598e-01  2.9785e-02  8.6060e-03  1.3965e-01 -2.5635e-02 -3.6133e-02
  1.1182e-01 -1.9824e-01  5.1270e-02  3.6328e-01 -2.4219e-01 -3.0273e-01
 -1.7773e-01 -2.4902e-02 -1.6797e-01 -1.6992e-01  3.4668e-02  5.2185e-03
  4.6387e-02  1.2891e-01  1.3672e-01  1.1279e-01  5.9570e-02  1.3672e-01
  1.0107e-01 -1.7676e-01 -2.5195e-01  5.9814e-02  3.4180e-01 -3.1128e-02
  1.0449e-01  6.1768e-02  1.2451e-01  4.0039e-01 -3.2227e-01  8.3984e-02
  3.9062e-02  5.8594e-03  7.0312e-02  1.7285e-01  1.3867e-01 -2.3145e-01
  2.8320e-01  1.4258e-01  3.4180e-01 -2.3926e-02 -1.0986e-01  3.3203e-02
 -5.4688e-02  1.5320e-02 -1.6211e-01  1.5820e-01 -2.5977e-01  2.0142e-02
 -1.6309e-01  1.3580e-03 -1.4453e-01 -5.6885e-02  4.2969e-02 -2.4658e-02
  1.8555e-01  4.4727e-01  9.5825e-03  1.3184e-01  9.8633e-02 -1.8555e-01
 -1.0010e-01 -1.3379e-01 -1.2500e-01  2.8320e-01  1.2305e-01  5.3223e-02
 -1.7773e-01  8.5938e-02 -2.1851e-02  2.0508e-02 -1.3965e-01  2.5146e-02
  1.3867

## 71. Loading the dataset

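Download the [Stanford Sentiment Treebank (SST)](https://dl.fbaipublicfiles.com/glue/data/SST-2.zip) distributed with the [General Language Understanding Evaluation (GLUE)](https://gluebenchmark.com/) benchmark, load the texts and polarity labels of the training set (train.tsv) and the development set (dev.tsv), and convert every text into a sequence of token IDs. Ignore words that are not covered by the word embedding vocabulary, leaving them out of the token sequence. Also remove from the training and development sets any example none of whose tokens are in the embedding vocabulary, since its token sequence would be empty (note that, because of this, the accuracies are no longer comparable with those obtained in the Chapter 7 experiments).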

Any representation of the examples is fine. For instance, the example "contains no wit , only labored gags", which is labeled negative, can be represented by the following dictionary object.

```
{'text': 'contains no wit , only labored gags',
 'label': tensor([0.]),
 'input_ids': tensor([ 3475,    87, 15888,    90, 27695, 42637])}
```

Here `text` is the text, `label` is the class label (`tensor([1.])` for positive, `tensor([0.])` for negative), and `input_ids` is the text's token sequence expressed as a sequence of IDs.
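One possible way to produce records in this format, sketched with NumPy arrays in place of tensors. It assumes whitespace tokenization (the SST-2 sentences shown in this chapter are lowercase with punctuation already split off); `encode_examples` and the toy vocabulary are hypothetical.

```python
import numpy as np

def encode_examples(rows, word_to_id):
    """rows: iterable of (text, label) pairs. Returns a list of dict records."""
    examples = []
    for text, label in rows:
        # Whitespace tokenization; words missing from the vocabulary are dropped
        input_ids = [word_to_id[w] for w in text.split() if w in word_to_id]
        if not input_ids:
            continue  # drop examples whose token sequence would be empty
        examples.append({
            "text": text,
            "label": np.array([float(label)], dtype=np.float32),  # shape (1,)
            "input_ids": np.array(input_ids, dtype=np.int32),
        })
    return examples

toy_vocab = {"<pad>": 0, "no": 1, "wit": 2, "gags": 3}
rows = [("contains no wit , only labored gags", 0), ("zzz qqq", 1)]
examples = encode_examples(rows, toy_vocab)
print(len(examples))             # 1
print(examples[0]["input_ids"])  # [1 2 3]
```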

In [9]:
import pandas as pd

train_path = './data/SST-2/train.tsv'
dev_path = './data/SST-2/dev.tsv'

train_df = pd.read_csv(train_path, sep='\t')
dev_df = pd.read_csv(dev_path, sep='\t')

train_df

| | sentence | label |
|---|---|---|
| 0 | hide new secretions from the parental units | 0 |
| 1 | contains no wit , only labored gags | 0 |
| 2 | that loves its characters and communicates som... | 1 |
| 3 | remains utterly satisfied to remain the same t... | 0 |
| 4 | on the worst revenge-of-the-nerds clichés the ... | 0 |
| ... | ... | ... |
| 67344 | a delightful comedy | 1 |
| 67345 | anguish , anger and frustration | 0 |
| 67346 | at achieving the modest , crowd-pleasing goals... | 1 |
| 67347 | a patient viewer | 1 |


In [44]:
import jax.numpy as jnp

def text_to_token_ids(text, word_to_id):
    words = text.lower().split()

    token_ids = [word_to_id[word] for word in words if word in word_to_id]
    return token_ids

train_data = []

for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Processing train data"):
    text = row['sentence']
    label = jnp.array(row['label'], dtype=jnp.int8)

    token_ids = text_to_token_ids(text, word_to_id)

    if token_ids:
        train_data.append({
            'text': text,
            'label': label,
            'input_ids': token_ids
        })

dev_data = []

for _, row in tqdm(dev_df.iterrows(), total=len(dev_df), desc="Processing dev data"):
    text = row['sentence']
    label = jnp.array(row['label'], dtype=jnp.int8)

    token_ids = text_to_token_ids(text, word_to_id)

    if token_ids:
        dev_data.append({
            'text': text,
            'label': label,
            'input_ids': token_ids
        })

Processing train data: 100%|███████████████████████████████████████████████████████████████████████| 67349/67349 [00:07<00:00, 9126.79it/s]
Processing dev data: 100%|█████████████████████████████████████████████████████████████████████████████| 872/872 [00:00<00:00, 8627.47it/s]


In [46]:
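import sys
# sys.getsizeof reports only the shallow size of the list object (its array of
# pointers), not the dicts, strings, and arrays the list refers to
print(sys.getsizeof(train_data))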

562488


In [13]:
train_data[0]

{'text': 'hide new secretions from the parental units ',
 'label': Array(0, dtype=int8),
 'input_ids': [5785, 66, 113845, 18, 12, 15095, 1594]}

## 72. Building a bag-of-words model

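Design a neural network (a logistic regression model) that represents the feature vector of a text as the average of its word embeddings and classifies the text as positive or negative via the inner product with a weight vector.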
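Written out, the model is $p(y=1 \mid x) = \sigma(\pmb{w}^\top \bar{\pmb{e}} + b)$ with $\bar{\pmb{e}} = \frac{1}{T}\sum_{t=1}^{T} \pmb{E}_{x_t,:}$, where $T$ is the number of non-padding tokens. The pure-JAX sketch below computes this average while excluding PAD positions, which matters once sequences are padded to a common length; `bow_features` and `predict_proba` are hypothetical names, not Flax APIs.

```python
import jax
import jax.numpy as jnp

def bow_features(E, input_ids, pad_id=0):
    """Mean of word embeddings over non-PAD positions.

    E:         (V, d) embedding matrix
    input_ids: (B, L) int array, padded with pad_id
    returns:   (B, d)
    """
    emb = E[input_ids]                                   # (B, L, d)
    mask = (input_ids != pad_id).astype(emb.dtype)       # (B, L), 1 for real tokens
    summed = (emb * mask[..., None]).sum(axis=1)         # (B, d)
    lengths = jnp.maximum(mask.sum(axis=1, keepdims=True), 1.0)  # avoid division by 0
    return summed / lengths

def predict_proba(w, b, features):
    """Logistic regression: sigmoid(features @ w + b), shape (B,)."""
    return jax.nn.sigmoid(features @ w + b)

E = jnp.array([[0.0, 0.0], [1.0, 3.0], [3.0, 1.0]])  # row 0 is the PAD vector
ids = jnp.array([[1, 2, 0, 0]])                       # two real tokens, two PADs
feats = bow_features(E, ids)                          # averages rows 1 and 2 only
print(feats)
print(predict_proba(jnp.zeros(2), 0.0, feats))
```

A plain `.mean(axis=1)` over the padded sequence would divide by 4 here instead of 2 and halve the feature vector.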

In [14]:
import jax
import jax.numpy as jnp
from tqdm import tqdm

def create_features_labels(data):
    features = []
    labels = []
    for sample in tqdm(data, desc='Creating features and labels'):
        input_ids = sample['input_ids']
        label = sample['label']

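        input_ids_array = jnp.array(input_ids)

        # Look up the (frozen) embedding of every token: (length, d_emb)
        token_embeddings = embedding_matrix[input_ids_array]

        # Bag of words: average the token embeddings into one d_emb vector
        sentence_feature = token_embeddings.mean(axis=0)

        features.append(sentence_feature)
        labels.append(label)

    features = jnp.stack(features)  # (num_examples, d_emb)
    labels = jnp.stack(labels)      # (num_examples,)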

    return features, labels


In [15]:
train_features, train_labels = create_features_labels(train_data)
dev_features, dev_labels = create_features_labels(dev_data)

Creating features and labels: 100%|████████████████████████████████████████████████████████████████| 66650/66650 [00:45<00:00, 1466.20it/s]
Creating features and labels: 100%|████████████████████████████████████████████████████████████████████| 872/872 [00:00<00:00, 1363.00it/s]


In [16]:
print(train_features.shape)
print(train_labels.shape)
print(dev_features.shape)
print(dev_labels.shape)

(66650, 300)
(66650,)
(872, 300)
(872,)


In [17]:
import flax.linen as nn

class LogisticRegression(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=1)(x)
        # squeeze(-1) turns (batch, 1) into (batch,) and keeps the batch axis even when batch size is 1
        return jax.nn.sigmoid(x.squeeze(-1))

## 73. Training the model

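Train the weight vector of the model designed in Problem 72 on the training set. Keep the word embedding matrix fixed during training (do not fine-tune it). Also make the training progress observable, for example by printing the loss value during training.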
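The models in this chapter apply a sigmoid and then clip the probabilities before taking logs. A hedged alternative is to have the model return the logit $z$ and use the identity $-[y \log \sigma(z) + (1-y)\log(1-\sigma(z))] = \operatorname{softplus}(z) - yz$, which stays finite for any $z$ without clipping. The sketch below (the function name is hypothetical) checks it against the clipped-free formula on moderate logits. Optax also provides `optax.sigmoid_binary_cross_entropy(logits, labels)`, which returns per-example losses computed from logits.

```python
import jax
import jax.numpy as jnp

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed from logits.

    Uses -[y log sigmoid(z) + (1-y) log(1-sigmoid(z))] = softplus(z) - y*z.
    """
    return jnp.mean(jax.nn.softplus(logits) - labels * logits)

z = jnp.array([-2.0, 0.5, 3.0])
y = jnp.array([0.0, 1.0, 1.0])
p = jax.nn.sigmoid(z)
naive = -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))
print(bce_with_logits(z, y), naive)  # the two values agree

# A confidently wrong prediction still gives a finite loss (about 100)
print(bce_with_logits(jnp.array([100.0]), jnp.array([0.0])))
```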

In [18]:
model = LogisticRegression()
key = jax.random.PRNGKey(0)  # fix the random seed explicitly as well
dummy_x = train_features[0:1]

params = model.init(key, dummy_x)['params']
print(jax.tree_util.tree_map(lambda x: x.shape, params))

{'Dense_0': {'bias': (1,), 'kernel': (300, 1)}}


In [19]:
def loss_fn(params, x_batch, y_batch):
    predictions = model.apply({'params': params}, x_batch)
    predictions = jnp.clip(predictions, 1e-7, 1 - 1e-7)
    log_likelihood = y_batch * jnp.log(predictions) + (1 - y_batch) * jnp.log(1 - predictions)
    return -jnp.mean(log_likelihood)

In [20]:
import optax

@jax.jit
def train_step(params, opt_state, x_batch, y_batch):
    loss_value, grads = jax.value_and_grad(lambda p: loss_fn(p, x_batch, y_batch))(params)

    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss_value

In [21]:
lr = 0.1
optimizer = optax.adam(lr)
opt_state = optimizer.init(params)

n_epochs = 100

train_labels = train_labels.astype(jnp.float32)

# Full-batch training: each "epoch" is a single gradient step on the entire training set
for epoch in range(1, n_epochs+1):
    params, opt_state, current_loss = train_step(params, opt_state, train_features, train_labels)

    if epoch % 10 == 0:
        print(f'Epoch: {epoch}, Loss: {current_loss}')

Epoch: 10, Loss: 0.4317786693572998
Epoch: 20, Loss: 0.3934055268764496
Epoch: 30, Loss: 0.38222143054008484
Epoch: 40, Loss: 0.3761337399482727
Epoch: 50, Loss: 0.3727072477340698
Epoch: 60, Loss: 0.3706419765949249
Epoch: 70, Loss: 0.3692324161529541
Epoch: 80, Loss: 0.3682350218296051
Epoch: 90, Loss: 0.36749252676963806
Epoch: 100, Loss: 0.36693626642227173


## 74. Evaluating the model

Compute the accuracy on the development set of the model trained in Problem 73.
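sklearn's `accuracy_score` works for this. As a sketch, the same number can also be computed directly in JAX without converting the arrays to NumPy; the `accuracy` helper is hypothetical, labels are assumed to be 0/1 integers, and probabilities are thresholded at 0.5.

```python
import jax.numpy as jnp

def accuracy(probs, labels, threshold=0.5):
    """Fraction of examples whose thresholded prediction matches the label."""
    preds = (probs > threshold).astype(jnp.int32)
    return jnp.mean(preds == labels.astype(jnp.int32))

probs = jnp.array([0.9, 0.2, 0.6, 0.4])
labels = jnp.array([1, 0, 0, 0], dtype=jnp.int8)
print(accuracy(probs, labels))  # 0.75
```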

In [22]:
predictions = model.apply({'params': params}, dev_features)
print(predictions.shape)
print(dev_labels.shape)

(872,)
(872,)


In [23]:
pred_labels = (predictions > 0.5).astype(jnp.int32)
true_labels = dev_labels

In [24]:
from sklearn.metrics import accuracy_score

# accuracy_score expects (y_true, y_pred)
accuracy = accuracy_score(np.asarray(true_labels), np.asarray(pred_labels))
print('accuracy:', accuracy)

accuracy: 0.801605504587156


## 75. Padding

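Implement a function `collate` that, given multiple examples, represents them together as a single tensor object. When the token sequences of the given examples differ in length, align them to the length of the longest one and pad with token ID 0. In addition, sort the examples in descending order of token-sequence length.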

For example, if the first four examples of the training set are represented as follows,

```
[{'text': 'hide new secretions from the parental units',
  'label': tensor([0.]),
  'input_ids': tensor([  5785,     66, 113845,     18,     12,  15095,   1594])},
 {'text': 'contains no wit , only labored gags',
  'label': tensor([0.]),
  'input_ids': tensor([ 3475,    87, 15888,    90, 27695, 42637])},
 {'text': 'that loves its characters and communicates something rather beautiful about human nature',
  'label': tensor([1.]),
  'input_ids': tensor([    4,  5053,    45,  3305, 31647,   348,   904,  2815,    47,  1276,  1964])},
 {'text': 'remains utterly satisfied to remain the same throughout',
  'label': tensor([0.]),
  'input_ids': tensor([  987, 14528,  4941,   873,    12,   208,   898])}]
```

the result of passing them through the `collate` function is expected to be the following.

```
{'input_ids': tensor([
    [     4,   5053,     45,   3305,  31647,    348,    904,   2815,     47,   1276,   1964],
    [  5785,     66, 113845,     18,     12,  15095,   1594,      0,      0,      0,      0],
    [   987,  14528,   4941,    873,     12,    208,    898,      0,      0,      0,      0],
    [  3475,     87,  15888,     90,  27695,  42637,      0,      0,      0,      0,      0]]),
 'label': tensor([
    [1.],
    [0.],
    [0.],
    [0.]])}
```
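A sketch of `collate` that follows this specification literally, returning one dictionary of stacked NumPy arrays (convert with `jnp.asarray` for JAX). It sorts by the number of token IDs and pads to the longest sequence among the given examples; `collate_batch` is a hypothetical name, and labels get shape `(B, 1)` as in the expected output.

```python
import numpy as np

def collate_batch(examples, pad_id=0):
    """Stack examples into one padded batch, longest token sequence first.

    Returns {'input_ids': (B, L_max) int32, 'label': (B, 1) float32}.
    """
    # sorted() is stable, so examples of equal length keep their original order
    examples = sorted(examples, key=lambda ex: len(ex["input_ids"]), reverse=True)
    max_len = len(examples[0]["input_ids"])
    input_ids = np.full((len(examples), max_len), pad_id, dtype=np.int32)
    for row, ex in enumerate(examples):
        ids = np.asarray(ex["input_ids"], dtype=np.int32)
        input_ids[row, : len(ids)] = ids  # the remaining positions stay pad_id
    labels = np.stack([np.asarray(ex["label"], dtype=np.float32).reshape(1) for ex in examples])
    return {"input_ids": input_ids, "label": labels}

batch = collate_batch([
    {"input_ids": [5785, 66, 113845, 18, 12, 15095, 1594], "label": [0.0]},
    {"input_ids": [3475, 87, 15888, 90, 27695, 42637], "label": [0.0]},
    {"input_ids": [4, 5053, 45, 3305, 31647, 348, 904, 2815, 47, 1276, 1964], "label": [1.0]},
    {"input_ids": [987, 14528, 4941, 873, 12, 208, 898], "label": [0.0]},
])
print(batch["input_ids"].shape)  # (4, 11)
print(batch["label"].ravel())    # [1. 0. 0. 0.]
```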


In [25]:
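def collate(data, max_length=None, pad_id=PAD_ID):
    # Sort by token-sequence length (not character length), longest first
    sorted_data = sorted(data, key=lambda sample: len(sample['input_ids']), reverse=True)
    if max_length is None:
        # Default: pad to the longest token sequence, as the task specifies
        max_length = len(sorted_data[0]['input_ids'])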

    collated_data = []

    for sample in tqdm(sorted_data):
        text = sample['text']
        input_ids = sample['input_ids']
        label = sample['label']

        current_length = len(input_ids)

        if current_length < max_length:
            num_padding = max_length - current_length
            padded_ids = input_ids + [pad_id] * num_padding
        else:
            padded_ids = input_ids[:max_length]
        
        padded_ids = jnp.array(padded_ids, dtype=jnp.int32)

        collated_data.append({
            'text': text,
            'input_ids': padded_ids,
            'label': label
        })
    
    return collated_data

In [26]:
collated_train_data = collate(train_data, max_length=270)
collated_dev_data = collate(dev_data, max_length=270)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 66650/66650 [00:08<00:00, 7583.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 872/872 [00:00<00:00, 7786.68it/s]


## 76. Mini-batch training

Using the padding procedure from Problem 75, train the model with mini-batches. Also compute the accuracy of the trained model on the development set.
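The loader used in this notebook drops the last incomplete batch. If every training example should be seen each epoch, a hedged alternative is to yield the shuffled indices in chunks and keep the shorter final chunk, as in this NumPy sketch (`iterate_minibatches` is hypothetical). Under `jax.jit`, that smaller last batch has a different shape and therefore triggers one additional compilation.

```python
import numpy as np

def iterate_minibatches(num_examples, batch_size, seed):
    """Yield shuffled index arrays that cover every example once per epoch.

    The last batch may be smaller than batch_size.
    """
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_examples)
    for start in range(0, num_examples, batch_size):
        yield perm[start:start + batch_size]

batches = list(iterate_minibatches(10, 4, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]

# Usage with feature/label arrays: x_batch, y_batch = features[idx], labels[idx]
```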

In [27]:
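# Note: these input_ids are padded to 270 tokens. Because row 0 (PAD) of the embedding
# matrix is zero, the mean in create_features_labels divides by 270 rather than by the
# true sentence length, which scales each feature vector down by (length / 270).
train_features, train_labels = create_features_labels(collated_train_data)
dev_features, dev_labels = create_features_labels(collated_dev_data)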

Creating features and labels: 100%|████████████████████████████████████████████████████████████████| 66650/66650 [00:43<00:00, 1531.21it/s]
Creating features and labels: 100%|████████████████████████████████████████████████████████████████████| 872/872 [00:00<00:00, 1484.48it/s]


In [28]:
print(train_features.shape)
print(train_labels.shape)
print(dev_features.shape)
print(dev_labels.shape)

(66650, 300)
(66650,)
(872, 300)
(872,)


In [29]:
import flax.linen as nn

class LogisticRegression(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=1)(x)
        # squeeze(-1) turns (batch, 1) into (batch,) and keeps the batch axis even when batch size is 1
        return jax.nn.sigmoid(x.squeeze(-1))

In [30]:
model = LogisticRegression()
key = jax.random.PRNGKey(0)  # fix the random seed explicitly as well
dummy_x = train_features[0:1]

params = model.init(key, dummy_x)['params']
print(jax.tree_util.tree_map(lambda x: x.shape, params))

{'Dense_0': {'bias': (1,), 'kernel': (300, 1)}}


In [31]:
from flax.training import train_state
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np

class TrainState(train_state.TrainState):
    pass

tx = optax.adam(learning_rate=0.1)
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

def data_loader(features, labels, batch_size, rng):
    N = features.shape[0]
    perm = random.permutation(rng, N)
    perm = perm[:(N // batch_size) * batch_size].reshape(-1, batch_size)
    for idx in perm:
        yield features[idx], labels[idx]


In [45]:
@jax.jit
def train_step(state, x, y):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, x)
        preds = jnp.clip(preds, 1e-7, 1 - 1e-7)
        return -jnp.mean(y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds))
    
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [46]:
batch_size = 32
n_epochs = 10
rng = random.PRNGKey(0)
train_labels = train_labels.astype(jnp.float32)

for epoch in range(1, n_epochs + 1):
    rng, input_rng = random.split(rng)
    losses = []
    for x_batch, y_batch in data_loader(train_features, train_labels, batch_size, input_rng):
        state, loss = train_step(state, x_batch, y_batch)
        losses.append(loss)
    print(f'Epoch {epoch}, Loss {np.mean(jax.device_get(losses)):.4f}')

Epoch 1, Loss 0.4823
Epoch 2, Loss 0.4088
Epoch 3, Loss 0.3963
Epoch 4, Loss 0.3908
Epoch 5, Loss 0.3881
Epoch 6, Loss 0.3860
Epoch 7, Loss 0.3849
Epoch 8, Loss 0.3838
Epoch 9, Loss 0.3832
Epoch 10, Loss 0.3824


In [47]:
probs = state.apply_fn({'params': state.params}, dev_features)

pred_labels = (probs > 0.5).astype(jnp.int32)

accuracy = accuracy_score(np.array(dev_labels), np.array(pred_labels))
print('accuracy:', accuracy)

accuracy: 0.8073394495412844


## 77. Training on a GPU

Run the model training of Problem 76 on a GPU. Also compute the accuracy of the trained model on the development set.
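Cell In [2] above set `jax_platform_name` to `'cpu'`, which makes the CPU JAX's default backend, and `jax.devices()` then listed only `CpuDevice(id=0)`. A minimal sketch for choosing the device explicitly, assuming that `jax_platform_name` line is not executed: take the first GPU if JAX reports one, otherwise fall back to the CPU, and move arrays there with `jax.device_put`. `pick_device` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def pick_device():
    """Return the first GPU if JAX can see one, otherwise the first CPU."""
    try:
        gpus = jax.devices("gpu")
    except RuntimeError:  # raised when no GPU backend is available
        gpus = []
    return gpus[0] if gpus else jax.devices("cpu")[0]

device = pick_device()
# Committing the features to the device makes jitted steps on them run there
x = jax.device_put(jnp.ones((4, 300)), device)
print(device.platform, x.devices())
```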

In [None]:
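# Done in 76. Note, however, that cell In [2] pinned JAX to the CPU (jax_platform_name='cpu'),
# and the devices listed at that point were only CpuDevice(id=0)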

## 78. Fine-tuning the word embeddings

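In the training of Problem 77, introduce fine-tuning that also updates the word embedding parameters at the same time. Also compute the accuracy of the trained model on the development set.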
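Fine-tuning makes $\pmb{E}$ a trainable parameter, and `jax.grad` then returns a dense gradient with the same shape as $\pmb{E}$, even though only the rows of tokens that appear in the batch are non-zero. The toy sketch below (hypothetical names, tiny sizes) shows this; with $|V| = 3000001$ and $d_{\mathrm{emb}} = 300$ in float32, such a gradient alone takes about 3.6 GB.

```python
import jax
import jax.numpy as jnp

def loss_from_embeddings(E, w, input_ids, label):
    """Logistic-regression loss in which the embedding matrix E is trainable."""
    feat = E[input_ids].mean(axis=0)           # bag-of-words feature
    z = feat @ w                               # logit
    return jax.nn.softplus(z) - label * z      # binary cross-entropy from the logit

V, d = 6, 3
E = jax.random.normal(jax.random.PRNGKey(0), (V, d))
w = jnp.ones(d)
ids = jnp.array([1, 4])
grad_E = jax.grad(loss_from_embeddings)(E, w, ids, 1.0)
print(grad_E.shape)  # (6, 3): same shape as E

# Only the rows of the tokens used in this example receive a gradient
nonzero_rows = jnp.nonzero(jnp.any(grad_E != 0, axis=1))[0]
print(nonzero_rows)
```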

In [11]:
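import jax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np

# Initialize the Embed table with the pretrained matrix (row 0 = PAD = zeros).
# The float16 values are cast to the parameter dtype (float32 by default), so the
# table alone occupies 3000001 x 300 x 4 bytes (about 3.6 GB).
init_embed = nn.initializers.constant(np_embedding_matrix)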

class FineTuneModel(nn.Module):
    vocab_size: int
    embed_dim: int

    @nn.compact
    def __call__(self, x):
        x_emb = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.embed_dim,
            embedding_init=init_embed
        )(x)
    
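        # Mean over all padded positions, PAD included. nn.Embed does not freeze row 0,
        # so the PAD vector also receives gradients and is updated during fine-tuning.
        x_feat = jnp.mean(x_emb, axis=1)

        logits = nn.Dense(features=1)(x_feat)
        return jax.nn.sigmoid(logits.squeeze(-1))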

In [12]:
print('Available devices:', jax.devices())

gpu_id = 1
device = jax.devices('gpu')[gpu_id]
print(f'Using device: {device}')

Available devices: [CudaDevice(id=0), CudaDevice(id=1)]
Using device: cuda:1


In [13]:
from flax.training import train_state
import optax

class TrainState(train_state.TrainState):
    pass

model = FineTuneModel(vocab_size=vocab_size, embed_dim=embedding_dim)
key = jax.random.PRNGKey(0)

dummy_x = collated_train_data[0]['input_ids'][None, :]
params = model.init(key, dummy_x)['params']

tx = optax.adam(learning_rate=0.1)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)


In [14]:
train_x = jnp.stack([d['input_ids'] for d in collated_train_data])
train_y = jnp.stack([d['label'] for d in collated_train_data]).astype(jnp.float32).squeeze()


dev_x = jnp.stack([d['input_ids'] for d in collated_dev_data])
dev_y = jnp.stack([d['label'] for d in collated_dev_data]).astype(jnp.float32).squeeze()

In [16]:
print(train_x.shape)
print(train_y.shape)

(66650, 270)
(66650,)


In [15]:
def data_loader(inputs: jnp.ndarray, labels: jnp.ndarray, batch_size: int, rng):
    N = inputs.shape[0]
    perm = random.permutation(rng, N)
    n_samples = (N // batch_size) * batch_size
    perm = perm[:n_samples].reshape(-1, batch_size)
    for idx in perm:
        yield inputs[idx], labels[idx]

In [18]:
@jax.jit
def train_step(state, x_batch, y_batch):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, x_batch)
        preds = jnp.clip(preds, 1e-7, 1 - 1e-7)
        return -jnp.mean(y_batch * jnp.log(preds) + (1-y_batch) * jnp.log(1-preds))
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [18]:
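import functools

from jax import pmap  # jax.pmap is a function, not a module, so `import jax.pmap as pmap` raises ModuleNotFoundError

n_devices = len(jax.devices('gpu'))
print(f"Training on {n_devices} GPUs")

pmap_keys = jax.random.split(jax.random.PRNGKey(0), n_devices)

# pmap maps over a leading device axis of every argument: `state` must first be replicated
# across the devices (e.g. with flax.jax_utils.replicate), and each batch reshaped to
# (n_devices, per_device_batch, ...). Data parallelism keeps a full copy of the parameters
# and the optimizer state on every device, so it does not lower per-device memory use.
@functools.partial(pmap, axis_name='devices')
def train_step_pmap(state, x_batch, y_batch):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, x_batch)
        preds = jnp.clip(preds, 1e-7, 1 - 1e-7)
        return -jnp.mean(y_batch * jnp.log(preds) + (1-y_batch) * jnp.log(1-preds))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average the gradients across devices
    grads = jax.lax.pmean(grads, axis_name='devices')
    state = state.apply_gradients(grads=grads)
    return state, loss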

In [None]:
import jax.random as random
import gc
batch_size = 1
n_epochs = 5
rng = random.PRNGKey(0)

with jax.default_device(device):
    for epoch in range(1, n_epochs+1):
        rng, key = random.split(rng)
        losses = []

        gc.collect()
        for x_batch, y_batch in data_loader(train_x, train_y, batch_size, key):
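            # Single-device jitted step: train_step_pmap would need a replicated state
            # and batches with a leading device axis
            state, loss = train_step(state, x_batch, y_batch)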
            losses.append(loss)
        print(f"[finetune] Epoch {epoch}, loss = {jnp.mean(jnp.stack(losses)):.4f}")

        gc.collect()

2025-05-19 17:39:47.791763: E external/xla/xla/service/gpu/gpu_hlo_schedule.cc:652] The byte size of input/output arguments (21600015624) exceeds the base limit (15763193856). This indicates an error in the calculation!
2025-05-19 17:39:47.796768: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 0B (0 bytes) by rematerialization; only reduced to 16.76GiB (18000009732 bytes), down from 16.76GiB (18000009732 bytes) originally
2025-05-19 17:39:58.219366: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_1_bfc) ran out of memory trying to allocate 3.35GiB (rounded to 3600001280)requested by op 
2025-05-19 17:39:58.219725: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ********************************************************************************************________
E0519 17:39:58.219766   61510 pjrt_stream_executor_client.cc:2839] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while tr

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3600001200 bytes.
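The failed allocation of 3,600,001,200 bytes is exactly one float32 copy of the embedding table: $3000001 \times 300 \times 4 = 3600001200$. Training with Adam keeps at least four tables of that size (the parameters, their gradient, and Adam's first- and second-moment estimates), about 14.4 GB, which is already close to the 15,763,193,856-byte limit reported in the log. One hedged way out is to keep only the words that occur in the SST-2 train and dev sets, as in the NumPy sketch below (`shrink_vocabulary` is hypothetical); the `input_ids` must then be recomputed with the returned mapping.

```python
import numpy as np

def shrink_vocabulary(token_lists, word_to_id, embeddings, pad_token="<pad>"):
    """Keep only the words that actually occur in the data.

    token_lists: iterable of token lists (e.g. text.lower().split() per example)
    word_to_id:  full vocabulary, with pad_token mapped to 0
    embeddings:  full (V, d) matrix whose row 0 is the PAD vector
    """
    used = sorted({t for toks in token_lists for t in toks
                   if t in word_to_id and t != pad_token})
    new_word_to_id = {pad_token: 0}
    for t in used:
        new_word_to_id[t] = len(new_word_to_id)
    old_ids = np.array([word_to_id[t] for t in used], dtype=np.int64)
    new_emb = np.zeros((len(new_word_to_id), embeddings.shape[1]), dtype=embeddings.dtype)
    new_emb[1:] = embeddings[old_ids]  # copy the pretrained rows in the new order
    return new_word_to_id, new_emb

# Toy example: "b" never occurs and "zzz" is out of vocabulary
full_vocab = {"<pad>": 0, "a": 1, "b": 2, "c": 3}
full_emb = np.arange(8, dtype=np.float32).reshape(4, 2)
small_vocab, small_emb = shrink_vocabulary([["c", "a", "zzz"], ["a"]], full_vocab, full_emb)
print(small_vocab)      # {'<pad>': 0, 'a': 1, 'c': 2}
print(small_emb.shape)  # (3, 2)
```

Including the dev-set words keeps their pretrained vectors available at evaluation time, even though only words seen in training receive updates.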

## 79. Changing the architecture

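Freely change the architecture of the neural network and train the model. Also compute the accuracy of the trained model on the development set. For example, try passing the text's feature vector (the average of its word embeddings) through a multi-layer neural network, or try training models such as a convolutional neural network (CNN) or a recurrent neural network (RNN).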