In [52]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import re
from sklearn.metrics.pairwise import cosine_similarity

### 1. 샘플 데이터

In [2]:
# https://www.bbc.com/news/articles/crljn5046epo
corpus = """
A group of 59 white South Africans has arrived in the US, where they are to be granted refugee status. President Donald Trump has said the refugee applications for the country's Afrikaner minority had been expedited as they were victims of racial discrimination. The South African government said the group were not suffering any such persecution that would merit refugee status. The Trump administration has halted all other refugee admissions, including for applicants from warzones. Human Rights Watch described the move as a cruel racial twist, saying that thousands of people - many black and Afghan refugees - had been denied refuge in the US. The group of white South Africans, who landed at Dulles airport near Washington DC on Monday, received a warm welcome from US authorities. Some held young children and waved small American flags in the arrival area adorned with red, white and blue balloons on the walls. The processing of refugees in the US often takes months, even years, but this group has been fast tracked. UNHCR - the United Nations refugee agency - confirmed to the BBC it wasn't involved in the vetting, as is usually the case. Asked directly on Monday why the Afrikaners' refugee applications had been processed faster than other groups, Trump said a genocide was taking place and that white farmers specifically were being targeted. Farmers are being killed, they happen to be white, but whether they're white or black makes no difference to me. But South African President Cyril Ramaphosa said he told Trump during a phone call the US assessment of the situation was not true. A refugee is someone who has to leave their country out of fear of political persecution, religious persecution, or economic persecution, Ramaphosa said. And they don't fit that bill. In response to a question from the BBC at Dulles airport, Deputy Secretary of State Christopher Landau said: It is not surprising, unfortunately, that a country from which refugees come does not concede that they are refugees. The US has criticised domestic South African policy, accusing the government of seizing land from white farmers without any compensation. In January President Ramaphosa signed a controversial law allowing the government to seize privately owned land without compensation in certain circumstances, when it is deemed equitable and in the public interest. But the government says no land has yet been seized under the act. There has been frustration in South Africa over the slow pace of land reform in the three decades since the end of the racist apartheid system. While black South Africans make up more than 90% of the population, they only hold 4% of all privately owned land, according to a 2017 report. One of Trump's closest advisers, South African-born Elon Musk, has previously said there was a genocide of white people in South Africa and accused the government of passing racist ownership laws. The claims of a genocide of white people have been widely discredited. Skin colour of South African farmers makes no difference to Trump. In a statement to the BBC, Gregory Meeks, ranking Democratic member of the House Foreign Affairs Committee, said the Trump administration's refugee resettlement was not just a racist dog whistle, it's a politically motivated rewrite of history. The Episcopal Church said it would no longer work with the federal government on refugee settlement because of the preferential treatment granted for the Afrikaners. Commenting on this news on X, Vice-President JD Vance posted, Crazy. Melissa Keaney, a lawyer with the International Refugee Assistance project, told the BBC the White House's decision to fast-track the Afrikaners' arrival amounted to a lot of hypocrisy and unequal treatment. Her organisation is suing the Trump administration after it indefinitely suspended the US Refugee Admissions Program (USRAP) in January. She said that policy had left over 120,000 conditionally approved refugees in limbo. Afrikaner author Max du Preez told the BBC's Newsday radio programme that claims of persecution of white South Africans were a total absurdity and based on nothing. Figures from the South African police show that in 2024, 44 murders were recorded on farms and smaller plots of agricultural land, with eight of those killed being farmers. South Africa does not report on crime statistics broken down by race but a majority of the country's farmers are white, while other people living on farms, such as workers, are mostly black. Bilateral relations between the US and South Africa have been strained since President Trump first tasked his administration with resettling Afrikaners, a group with mostly Dutch ancestry, in the US. In March, South Africa's ambassador to the US, Ebrahim Rasool, was expelled after accusing President Trump of using white victimhood as a dog whistle, leading to the US accusing Mr Rasool of race-baiting. The US has also criticised South Africa for taking an aggressive position against Israel at the International Court of Justice (ICJ), where Pretoria has accused Prime Minister Benjamin Netanyahu's government of genocide against Palestinians - a claim the Israelis strongly reject. President Trump's openness to accepting Afrikaner refugees comes as the US has engaged in a wider crackdown on migrants and asylum seekers from other countries.
"""

In [11]:
corpus = corpus.lower()
corpus = re.sub(r'[^\w\s]', '',corpus)
corpus

'\na group of 59 white south africans has arrived in the us where they are to be granted refugee status president donald trump has said the refugee applications for the countrys afrikaner minority had been expedited as they were victims of racial discrimination the south african government said the group were not suffering any such persecution that would merit refugee status the trump administration has halted all other refugee admissions including for applicants from warzones human rights watch described the move as a cruel racial twist saying that thousands of people  many black and afghan refugees  had been denied refuge in the us the group of white south africans who landed at dulles airport near washington dc on monday received a warm welcome from us authorities some held young children and waved small american flags in the arrival area adorned with red white and blue balloons on the walls the processing of refugees in the us often takes months even years but this group has been f

In [13]:
tokens = corpus.split()
tokens

['a',
 'group',
 'of',
 '59',
 'white',
 'south',
 'africans',
 'has',
 'arrived',
 'in',
 'the',
 'us',
 'where',
 'they',
 'are',
 'to',
 'be',
 'granted',
 'refugee',
 'status',
 'president',
 'donald',
 'trump',
 'has',
 'said',
 'the',
 'refugee',
 'applications',
 'for',
 'the',
 'countrys',
 'afrikaner',
 'minority',
 'had',
 'been',
 'expedited',
 'as',
 'they',
 'were',
 'victims',
 'of',
 'racial',
 'discrimination',
 'the',
 'south',
 'african',
 'government',
 'said',
 'the',
 'group',
 'were',
 'not',
 'suffering',
 'any',
 'such',
 'persecution',
 'that',
 'would',
 'merit',
 'refugee',
 'status',
 'the',
 'trump',
 'administration',
 'has',
 'halted',
 'all',
 'other',
 'refugee',
 'admissions',
 'including',
 'for',
 'applicants',
 'from',
 'warzones',
 'human',
 'rights',
 'watch',
 'described',
 'the',
 'move',
 'as',
 'a',
 'cruel',
 'racial',
 'twist',
 'saying',
 'that',
 'thousands',
 'of',
 'people',
 'many',
 'black',
 'and',
 'afghan',
 'refugees',
 'had',
 'be

### 2. 단어 -> 인덱스 매핑

In [30]:
vocab = list(set(tokens))
vocab
vocab_size = len(vocab)
### 순서는 무작위로 생성된다.

In [19]:
word_to_idx = {w: i for i,w in enumerate(vocab)}
word_to_idx

{'state': 0,
 'jd': 1,
 'bbc': 2,
 'being': 3,
 'ranking': 4,
 'cruel': 5,
 'slow': 6,
 'seizing': 7,
 'expedited': 8,
 'foreign': 9,
 'rights': 10,
 'her': 11,
 'directly': 12,
 'absurdity': 13,
 'more': 14,
 '59': 15,
 'court': 16,
 'musk': 17,
 'laws': 18,
 'that': 19,
 'arrived': 20,
 'children': 21,
 'an': 22,
 'leading': 23,
 'interest': 24,
 'openness': 25,
 'by': 26,
 'reject': 27,
 'than': 28,
 'bilateral': 29,
 'accepting': 30,
 'persecution': 31,
 'figures': 32,
 'administration': 33,
 'cyril': 34,
 'move': 35,
 'unfortunately': 36,
 'fear': 37,
 'racist': 38,
 '2024': 39,
 'such': 40,
 'specifically': 41,
 'whether': 42,
 'recorded': 43,
 'january': 44,
 'keaney': 45,
 'a': 46,
 'status': 47,
 'dutch': 48,
 'genocide': 49,
 'watch': 50,
 'house': 51,
 'assessment': 52,
 'which': 53,
 'warzones': 54,
 'according': 55,
 'asked': 56,
 'previously': 57,
 'program': 58,
 'left': 59,
 'hold': 60,
 'welcome': 61,
 'resettling': 62,
 'as': 63,
 'targeted': 64,
 'three': 65,
 'polic

In [26]:
idx_to_word = {i:w for w,i in word_to_idx.items()}
idx_to_word

{0: 'state',
 1: 'jd',
 2: 'bbc',
 3: 'being',
 4: 'ranking',
 5: 'cruel',
 6: 'slow',
 7: 'seizing',
 8: 'expedited',
 9: 'foreign',
 10: 'rights',
 11: 'her',
 12: 'directly',
 13: 'absurdity',
 14: 'more',
 15: '59',
 16: 'court',
 17: 'musk',
 18: 'laws',
 19: 'that',
 20: 'arrived',
 21: 'children',
 22: 'an',
 23: 'leading',
 24: 'interest',
 25: 'openness',
 26: 'by',
 27: 'reject',
 28: 'than',
 29: 'bilateral',
 30: 'accepting',
 31: 'persecution',
 32: 'figures',
 33: 'administration',
 34: 'cyril',
 35: 'move',
 36: 'unfortunately',
 37: 'fear',
 38: 'racist',
 39: '2024',
 40: 'such',
 41: 'specifically',
 42: 'whether',
 43: 'recorded',
 44: 'january',
 45: 'keaney',
 46: 'a',
 47: 'status',
 48: 'dutch',
 49: 'genocide',
 50: 'watch',
 51: 'house',
 52: 'assessment',
 53: 'which',
 54: 'warzones',
 55: 'according',
 56: 'asked',
 57: 'previously',
 58: 'program',
 59: 'left',
 60: 'hold',
 61: 'welcome',
 62: 'resettling',
 63: 'as',
 64: 'targeted',
 65: 'three',
 66: 'p

In [21]:
def geneate_pairs(corpus, window_size=1):
    pairs=[]
    for i in range(len(corpus)):
        center = word_to_idx[corpus[i]]
        for j in range(i-window_size, i+window_size + 1):
            # 자기 자신과의 쌍은 피하고 j는 0에서 크기보다 작음
            if j!=i and 0 <=j <len(corpus):
                context = word_to_idx[corpus[j]]
                pairs.append(((center,context)))
    return pairs

In [24]:
pairs = geneate_pairs(tokens)
pairs
### 결과는 문장 순서 기반으로 인덱스별 쌍 (window=1이므로 양옆)

[(46, 121),
 (121, 46),
 (121, 406),
 (406, 121),
 (406, 15),
 (15, 406),
 (15, 359),
 (359, 15),
 (359, 141),
 (141, 359),
 (141, 282),
 (282, 141),
 (282, 179),
 (179, 282),
 (179, 20),
 (20, 179),
 (20, 328),
 (328, 20),
 (328, 136),
 (136, 328),
 (136, 183),
 (183, 136),
 (183, 234),
 (234, 183),
 (234, 310),
 (310, 234),
 (310, 320),
 (320, 310),
 (320, 317),
 (317, 320),
 (317, 239),
 (239, 317),
 (239, 302),
 (302, 239),
 (302, 216),
 (216, 302),
 (216, 47),
 (47, 216),
 (47, 231),
 (231, 47),
 (231, 168),
 (168, 231),
 (168, 357),
 (357, 168),
 (357, 179),
 (179, 357),
 (179, 131),
 (131, 179),
 (131, 136),
 (136, 131),
 (136, 216),
 (216, 136),
 (216, 299),
 (299, 216),
 (299, 303),
 (303, 299),
 (303, 136),
 (136, 303),
 (136, 293),
 (293, 136),
 (293, 398),
 (398, 293),
 (398, 122),
 (122, 398),
 (122, 358),
 (358, 122),
 (358, 391),
 (391, 358),
 (391, 8),
 (8, 391),
 (8, 63),
 (63, 8),
 (63, 310),
 (310, 63),
 (310, 99),
 (99, 310),
 (99, 79),
 (79, 99),
 (79, 406),
 (406,

In [25]:
# 인덱스 쌍 → 단어 쌍
for center_idx, context_idx in pairs[:10]:
    print(f"({idx_to_word[center_idx]}, {idx_to_word[context_idx]})")

(a, group)
(group, a)
(group, of)
(of, group)
(of, 59)
(59, of)
(59, white)
(white, 59)
(white, south)
(south, white)


### 3. 네거티브 샘플링을 위해 unigram을 3/4 분포로 정의

In [28]:
word_freq = np.array([corpus.count(w) for w in vocab])
word_freq
### 단어 빈도수 (문장 순서로)
### 전체 문장에 vocab을 하나씩 꺼내서 개수를 저장

array([  2,   1,   5,   3,   1,   1,   1,   1,   1,   1,   1,  11,   1,
         1,   1,   1,   1,   1,   1,   9,   1,   1,  62,   1,   1,   1,
         1,   1,   2,   1,   1,   5,   1,   4,   1,   1,   1,   1,   3,
         1,   2,   1,   1,   1,   2,   1, 383,   2,   1,   4,   1,   2,
         1,   1,   1,   1,   2,   1,   2,   1,   1,   1,   1,  40,   1,
         1,   1,   1,   1,   1,   1,   1,   1,   1,  19,   1,   1,   1,
         1,   1,   1,   1,   1,   2,   1,   2,   1,   2,   6,   1,   1,
         1,   5,   1,   1,   2,   7,   1,   3,   5,   2,  17,   1,   1,
         1,   2,   1,   1,   8,   1,   1,   2,   1,   2,   6,  83,   1,
         2,   1,   1,   2,   6,   1,   1,   1,   1,   1,   3,  21,   1,
         1,  10,   1,   1,   1,  12,  74,   1,   1,   4,   2,  16,   1,
         1,   7,   2,   2,   2,   7,   4,   1,   1,   1,   4,   1,   1,
         1,   8,   1,   2,   1,   1,   1,   1,   1,   1,   1,   1,   1,
         2,   1,   1,   1,   1,   1,   1,   2,   1,   1,  12,  4

In [29]:
word_power = word_freq ** 0.75
unigram_dist = word_power / np.sum(word_power)
unigram_dist
### 0.75만큼 적용한 것에 빈도수 분포

array([0.00170067, 0.00101122, 0.00338122, 0.00230509, 0.00101122,
       0.00101122, 0.00101122, 0.00101122, 0.00101122, 0.00101122,
       0.00101122, 0.00610788, 0.00101122, 0.00101122, 0.00101122,
       0.00101122, 0.00101122, 0.00101122, 0.00101122, 0.00525446,
       0.00101122, 0.00101122, 0.02234293, 0.00101122, 0.00101122,
       0.00101122, 0.00101122, 0.00101122, 0.00170067, 0.00101122,
       0.00101122, 0.00338122, 0.00101122, 0.00286017, 0.00101122,
       0.00101122, 0.00101122, 0.00101122, 0.00230509, 0.00101122,
       0.00170067, 0.00101122, 0.00101122, 0.00101122, 0.00170067,
       0.00101122, 0.08754783, 0.00170067, 0.00101122, 0.00286017,
       0.00101122, 0.00170067, 0.00101122, 0.00101122, 0.00101122,
       0.00101122, 0.00170067, 0.00101122, 0.00170067, 0.00101122,
       0.00101122, 0.00101122, 0.00101122, 0.0160839 , 0.00101122,
       0.00101122, 0.00101122, 0.00101122, 0.00101122, 0.00101122,
       0.00101122, 0.00101122, 0.00101122, 0.00101122, 0.00920

### 4. 네거티브 샘플링 함수

#### 배경
$$P(w_O \mid w_C) = \frac{\exp(\vec{v}{w_O} \cdot \vec{v}{w_C})}{\sum_{w \in V} \exp(\vec{v}w \cdot \vec{v}{w_C})}$$


$w_c$: 중심 단어
$w_o$: 주변 단어
$V$: 전체 단어 집합

기존 softmax는 계산량이 커서 속도가 느리다.

#### 네거티브 샘플링
$$\log \sigma(\vec{v}{w_O} \cdot \vec{v}{w_C}) + \sum_{k=1}^{K} \log \sigma(-\vec{v}{w_k} \cdot \vec{v}{w_C})$$

첫 항: 실제 단어 쌍에 대해 진짜(1) 라고 예측하도록 한다.
나머지 K개의 항: 네거티브 단어들에 대해 가짜(0)라고 예측하도록 한다.

진짜 주변 단어는 1 샘프링된 나머지는 0으로 분류하는 문제

In [31]:
# num_samples == K
# unigram 분포에서 단어를 랜덤하게 뽑는다
# 만약 샘플이 실제 주변단어가 아니면 추가한다
# 쌍이 여러 개라 정답은 독립적으로 정답처리
def get_negative_samples(pos_idx, num_samples=5):
    neg_samples = []
    while len(neg_samples) <num_samples:
        sample = np.random.choice(vocab_size, p=unigram_dist)
        if(sample != pos_idx):
            neg_samples.append(sample)
    return neg_samples

### 5. 모델 정의

In [45]:
class SGNS(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.center_emb = nn.Embedding(vocab_size, embedding_dim)
        self.context_emb = nn.Embedding(vocab_size,embedding_dim)

    def forward(self,center,context,negative):
        center_v = self.center_emb(center)
        context_v = self.context_emb(context)
        negative_v = self.context_emb(negative)

        # 임베딩 벡터의 곱, 실제쌍은 score가 높으면 좋고 가짜 쌍은 낮아야 좋게 설정
        # 로그여서 스코어가 높을수록 1에 가까워서 로스는 0으로 유도
        pos_score = torch.sum(center_v * context_v, dim=1)
        pos_loss = torch.log(torch.sigmoid(pos_score))

        # k개의 네거티브 샘플링은 한번에 처리하기 위해 적용
        # negative_v : [1,K,D]
        # center_v : [1,D]
        # center_v.unsqueeze(2) : [1,D,1]
        # bmm 결과 : [1,K,1]
        # squeeze(2) : [1,K]
        # 쉽게 보자면 네거티브 스코어는 주변 단어 스코어와는 반대되는 부분을 수치화하고 0 score가 낮게 학습되는 것
        neg_score = torch.bmm(negative_v, center_v.unsqueeze(2)).squeeze(2)
        # 얘는 스코어가 낮을수록 (-) 가 붙어서  loss가 작아짐
        neg_loss = torch.sum(torch.log(torch.sigmoid(-neg_score)),dim=1)
        return -torch.mean(pos_loss + neg_loss)


### 6. 학습

In [48]:
embedding_dim = 10
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = SGNS(vocab_size,embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(),lr=0.01)
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())

True
True


In [49]:
for epoch in range(100):
    total_loss =0
    for center, context in pairs:
        neg_samples = get_negative_samples(context)
        # 중심단어 input
        center_tensor = torch.LongTensor([center]).to(device)
        # 주변 단어 진짜 정답 레이블
        context_tensor = torch.LongTensor([context]).to(device)
        # 주변단어 가짜 정답 레이블
        neg_tensor = torch.LongTensor([neg_samples]).to(device)

        loss = model(center_tensor, context_tensor,neg_tensor)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

Epoch 0, Loss: 11643.9022
Epoch 10, Loss: 2583.6119
Epoch 20, Loss: 2225.7217
Epoch 30, Loss: 1984.0574
Epoch 40, Loss: 2003.8076
Epoch 50, Loss: 2014.6086
Epoch 60, Loss: 2052.8716
Epoch 70, Loss: 2023.6410
Epoch 80, Loss: 2138.8216
Epoch 90, Loss: 1919.6189


In [51]:
for i, word in idx_to_word.items():
    vec = model.center_emb.weight[i].detach().cpu().numpy()
    print(f"{word}: {vec}")

state: [ 3.5921488  -2.8894877   0.87167406  1.0098456  -0.69468707 -6.4937983
  0.34394     0.10874554  0.33798406  4.5449386 ]
jd: [ 3.5209503 -2.2794771 -1.6680069 -1.5255908  2.14735   -1.1786442
 -0.8470788  4.361628  -2.2118232  2.2583315]
bbc: [ 2.5891347  -0.23796676 -1.3967903  -1.5145919   1.3271029  -0.72262084
 -0.9107621  -1.1610028   1.0642071   1.77469   ]
being: [ 1.8093702  -4.512679    3.9652648  -2.1945667  -1.1614519   0.7792116
 -0.39577535 -0.93489     1.4587334   0.7231353 ]
ranking: [-0.24232681 -1.3755316   4.1942396  -2.5826237  -2.5301497  -1.403187
  1.5335628   5.5019045  -0.41942227  2.209379  ]
cruel: [ 2.6191168   0.12139864 -0.9738847  -5.7057424  -0.8241071  -2.5614543
  1.6375846   1.2215669   0.37891468 -0.5857021 ]
slow: [ 2.9455156   2.291538   -0.44922698 -3.125803   -0.12991828 -3.0993197
 -1.3669636   0.43925753  5.764821    0.7342547 ]
seizing: [ 3.9486728  1.8019418  2.2143362  1.6883905 -2.5529456 -6.0021305
 -2.8166552  2.7214909  3.8237774 

In [54]:
def find_similar_words(target_word, word_to_idx, idx_to_word, embedding_weight, top_k=10):
    if target_word not in word_to_idx:
        print(f"'{target_word}' not in vocabulary.")
        return

    target_idx = word_to_idx[target_word]

    # 1. 벡터 추출
    target_vec = embedding_weight[target_idx].detach().cpu().numpy().reshape(1, -1)
    all_vecs = embedding_weight.detach().cpu().numpy()

    # 2. 코사인 유사도 계산
    similarities = cosine_similarity(target_vec, all_vecs)[0]

    # 3. 유사도 정렬 (자기 자신 제외)
    similar_indices = similarities.argsort()[::-1]
    top_indices = [i for i in similar_indices if i != target_idx][:top_k]

    # 4. 결과 출력
    for i in top_indices:
        print(f"{idx_to_word[i]}: similarity = {similarities[i]:.4f}")

In [55]:
find_similar_words("trump", word_to_idx, idx_to_word, model.center_emb.weight, top_k=10)

since: similarity = 0.8306
three: similarity = 0.8114
minister: similarity = 0.7972
all: similarity = 0.7956
government: similarity = 0.7928
tasked: similarity = 0.7767
suspended: similarity = 0.7622
history: similarity = 0.7587
united: similarity = 0.7442
bbcs: similarity = 0.7431


| 단어        | 유사도   | 해석                                                                 |
|-------------|----------|----------------------------------------------------------------------|
| since       | 0.8306   | 시점 표현. 트럼프 정권 이후, 관계가 변했다는 문맥에 자주 등장         |
| three       | 0.8114   | 기사 중 “three decades since apartheid” → 문맥 연결                  |
| minister    | 0.7972   | prime minister, Netanyahu 등과의 외교 맥락                            |
| all         | 0.7956   | “Trump administration halted all…” 같은 표현                         |
| government  | 0.7928   | 당연히 자주 함께 나옴 (미국/남아공 정부 모두 관련)                    |
| tasked      | 0.7767   | “Trump tasked his administration with…” 같은 문장 존재               |
| suspended   | 0.7622   | “Trump suspended USRAP…” → 정책 관련                                  |
| history     | 0.7587   | 과거와 현재의 외교, 인종 문제 맥락에서 연관                          |
| united      | 0.7442   | 미국(United States)과 유엔(United Nations) 맥락 모두 가능             |
| bbcs        | 0.7431   | 기사에서 출처로 등장하는 BBC 인용 (예: “told the BBC’s Newsday”)     |