<a href="https://colab.research.google.com/github/ShannonBonilla/COMM557_Project/blob/main/tiktok_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup. The upload prompt is presumably used to supply the input CSV
# read later in this cell (final_songs_with_lyrics.csv) — confirm. That file
# is opened by name afterwards; `uploaded` itself is not referenced again.
# NOTE: the `!` and `%` lines below are IPython syntax, so this cell only runs
# inside Jupyter/Colab, not as a plain .py script.
from google.colab import files
uploaded = files.upload()


# Install graph, plotting and dataframe dependencies (IPython shell escape).
!pip install networkx matplotlib pandas --quiet
import pandas as pd
import networkx as nx
from collections import Counter, defaultdict
from itertools import combinations
import matplotlib.pyplot as plt
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Merge BERTopic results with original DataFrame.
# NOTE(review): `topics` and `topic_model` are not created in this notebook;
# they are expected to exist from an earlier BERTopic run in the same session.
df = pd.read_csv("final_songs_with_lyrics.csv")

# Keep only rows with lyrics. The flag is therefore always False here, but it
# is still saved because the TikTok filter further down queries `lyrics_missing`.
df_with_lyrics = df.dropna(subset=['lyrics']).copy()
df_with_lyrics['lyrics_missing'] = df_with_lyrics['lyrics'].isnull()

# Ensure topics match. Alignment is positional: topics[i] is assumed to belong
# to the i-th row with lyrics (TODO confirm BERTopic was fit in this order).
# An explicit raise (not assert) keeps the check active under `python -O`.
if len(df_with_lyrics) != len(topics):
    raise ValueError(
        f"Mismatch: topics count ({len(topics)}) doesn't match "
        f"lyrics rows ({len(df_with_lyrics)})"
    )

# Add topic IDs
df_with_lyrics['topic'] = topics

# Optionally keep topic names (best effort: topic_model may be unavailable).
try:
    topic_names = {t[0]: t[1] for t in topic_model.get_topic_info()[['Topic', 'Name']].values}
    df_with_lyrics['topic_name'] = df_with_lyrics['topic'].map(topic_names)
except Exception as e:
    print("Warning: could not map topic names:", e)

df_with_lyrics.to_csv("dataset_with_topics.csv", index=False)
print("Saved dataset_with_topics.csv with topics column")

# Preview only columns that exist: 'topic_name' is absent whenever the name
# mapping above failed, and selecting it unconditionally would raise KeyError.
preview_cols = ['track_name', 'artist_name', 'source', 'topic', 'topic_name']
print(df_with_lyrics[[c for c in preview_cols if c in df_with_lyrics.columns]].head())

# Reload the topic-annotated dataset and keep TikTok rows whose lyrics exist.
df = pd.read_csv("dataset_with_topics.csv")
is_tiktok = df['source'] == 'tiktok'
has_lyrics = df['lyrics_missing'] == False  # noqa: E712 (elementwise compare)
tiktok = df[is_tiktok & has_lyrics].copy()

# Node labels: sequential string ids "0", "1", ... in row order for songs,
# and "T<topic>" for topics.
tiktok['song_id'] = [str(position) for position in range(len(tiktok))]
tiktok['topic_id'] = 'T' + tiktok['topic'].astype(str)

song_count = len(tiktok)
distinct_topic_count = tiktok['topic'].nunique()
print(f"Total TikTok songs: {song_count}")
print(f"Unique topics: {distinct_topic_count}")

# Create bipartite graph: songs (bipartite=0) linked to their topic
# (bipartite=1), one edge per distinct (song_id, topic_id) row.
bipartite_edges = tiktok[['song_id', 'topic_id']].drop_duplicates()
B = nx.Graph()

songs = tiktok['song_id'].unique()
# Named `topic_nodes`, not `topics`: the latter holds the BERTopic assignments
# used by the length check earlier in this cell, and overwriting it here made
# re-running the cell fail that check.
topic_nodes = tiktok['topic_id'].unique()
B.add_nodes_from(songs, bipartite=0, node_type='song')
B.add_nodes_from(topic_nodes, bipartite=1, node_type='topic')
# name=None yields plain (song_id, topic_id) tuples instead of namedtuples.
B.add_edges_from(bipartite_edges.itertuples(index=False, name=None))

# Per-song metadata and audio features keyed by song_id; also reused below for
# the song-song network and the community summaries.
song_attrs = tiktok.set_index('song_id')[
    ['track_name', 'artist_name', 'danceability', 'energy', 'loudness', 'tempo']
].to_dict('index')
nx.set_node_attributes(B, song_attrs)

nx.write_gexf(B, "tiktok_song_topic_bipartite.gexf")
print(f"\nBipartite graph: {B.number_of_nodes()} nodes, {B.number_of_edges()} edges")

# Create weighted song-song network: two songs are linked when they share at
# least one topic; the edge weight is the number of shared topics.
song2topics = tiktok.groupby('song_id')['topic_id'].apply(set).to_dict()

# Invert to topic -> songs so only songs that actually share a topic get paired,
# instead of intersecting topic sets for every one of the n*(n-1)/2 song pairs.
topic2songs = defaultdict(list)
for song_id, song_topic_set in song2topics.items():
    for topic_id in song_topic_set:
        topic2songs[topic_id].append(song_id)

pair_topics = defaultdict(set)
for topic_id, members in topic2songs.items():
    # `members` follows song2topics order, so every pair is keyed (earlier, later).
    for s1, s2 in combinations(members, 2):
        pair_topics[(s1, s2)].add(topic_id)

# Emit edges in the same order the all-pairs scan over song2topics would,
# so node/edge insertion order in G is unchanged.
song_order = {song_id: pos for pos, song_id in enumerate(song2topics)}
edges = [
    (s1, s2, {'weight': len(shared), 'shared_topics': sorted(shared)})
    for (s1, s2), shared in sorted(
        pair_topics.items(),
        key=lambda item: (song_order[item[0][0]], song_order[item[0][1]]),
    )
]

# Nodes enter G only through edges, so G has no isolates: songs whose topic
# no other song shares are simply absent from this network.
G = nx.Graph()
G.add_edges_from(edges)

# Add node attributes (song metadata) to every song present in the network.
nx.set_node_attributes(G, {node: song_attrs[node] for node in G if node in song_attrs})

n_nodes = G.number_of_nodes()
print(f"\nSong network: {n_nodes} nodes, {G.number_of_edges()} edges")
# Guard: with no shared topics G is empty and the mean would divide by zero.
avg_degree = sum(deg for _, deg in G.degree()) / n_nodes if n_nodes else 0.0
print(f"Average degree: {avg_degree:.2f}")
print(f"Density: {nx.density(G):.4f}")

# Community detection via greedy modularity maximisation on the weighted graph.
from networkx.algorithms import community

communities = community.greedy_modularity_communities(G, weight='weight')

# song_id -> position of its community in `communities`; also stored on each
# node as the 'community' attribute.
node_to_community = {
    member: idx
    for idx, members in enumerate(communities)
    for member in members
}
nx.set_node_attributes(G, node_to_community, name='community')

community_sizes = sorted((len(members) for members in communities), reverse=True)
print(f"\nDetected {len(communities)} communities")
print(f"Community sizes: {community_sizes}")

# Analyze each community: size, topic mix, and mean audio features.
audio_feature_names = ['danceability', 'energy', 'loudness', 'tempo']
community_analysis = []
for i, comm in enumerate(communities):
    comm_songs = list(comm)
    topic_counts = Counter(
        t for song in comm_songs for t in song2topics.get(song, set())
    )

    # Collect each feature's values, skipping missing ones (NaN/None): with a
    # plain sum(), one song lacking a feature turned the whole community's
    # average for that feature into NaN.
    audio_features = {feature: [] for feature in audio_feature_names}
    for song in comm_songs:
        attrs = song_attrs.get(song, {})
        for feature in audio_feature_names:
            value = attrs.get(feature)
            if pd.notna(value):
                audio_features[feature].append(value)

    # Fixed column order; a feature with no usable values is omitted here and
    # therefore shows up as NaN in community_df.
    avg_audio = {
        f'avg_{feature}': sum(values) / len(values)
        for feature, values in audio_features.items()
        if values
    }
    community_analysis.append({
        'community_id': i,
        'size': len(comm),
        'num_unique_topics': len(topic_counts),
        'top_topics': topic_counts.most_common(3),
        **avg_audio
    })

community_df = pd.DataFrame(community_analysis)
print("\n=== Community Analysis ===")
print(community_df.to_string())

# [Optional][could finish after Midterm]Visualization
plt.figure(figsize=(15, 15))
pos = nx.spring_layout(G, weight='weight', k=1, iterations=50, seed=42)

# Pick tab20's discrete colours by index. Passing raw community ids with
# cmap= normalises them over [min, max]; with more than ~21 communities,
# neighbouring ids (including the two largest, 0 and 1) can fall into the same
# colour bin. Indexing with id % 20 keeps the first 20 communities distinct.
# -1 is only a fallback for a node without a community entry.
palette = plt.cm.tab20
node_colors = [
    palette(node_to_community.get(node, -1) % palette.N) for node in G.nodes()
]
# Edge thickness grows with the number of shared topics.
edge_widths = [G[u][v]['weight'] * 0.5 for u, v in G.edges()]

nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=200, alpha=0.8)
nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.3)

plt.title("TikTok Song Network: Weighted by Shared Topics, Colored by Community", fontsize=16, fontweight='bold')
plt.axis('off')
plt.tight_layout()
plt.savefig('tiktok_weighted_network.png', dpi=300, bbox_inches='tight')
plt.show()

# Save output.
# networkx's GEXF writer interprets a list-valued attribute as dynamic
# (value, start, end) spells, so writing the raw `shared_topics` list either
# fails to unpack or emits bogus intervals. Export a copy (edge data dicts are
# copied, G itself is untouched) with the list flattened to "T1,T2,...".
G_export = G.copy()
for _, _, data in G_export.edges(data=True):
    if isinstance(data.get('shared_topics'), (list, tuple, set)):
        data['shared_topics'] = ','.join(map(str, data['shared_topics']))
nx.write_gexf(G_export, "tiktok_weighted_song_network.gexf")

# One row per song: metadata, community, and (weighted) degree.
# Explicit columns keep the CSV header even if the network is empty.
community_mapping = pd.DataFrame(
    [
        {'song_id': node,
         'track_name': G.nodes[node].get('track_name', ''),
         'artist_name': G.nodes[node].get('artist_name', ''),
         'community': G.nodes[node]['community'],
         'degree': G.degree(node),
         'weighted_degree': G.degree(node, weight='weight')}
        for node in G.nodes()
    ],
    columns=['song_id', 'track_name', 'artist_name', 'community', 'degree', 'weighted_degree'],
)
community_mapping.to_csv('tiktok_community_assignments.csv', index=False)

# One row per edge, with the shared topic ids comma-joined.
edge_list = pd.DataFrame(
    [
        {'source': u, 'target': v,
         'weight': G[u][v]['weight'],
         'shared_topics': ','.join(G[u][v]['shared_topics'])}
        for u, v in G.edges()
    ],
    columns=['source', 'target', 'weight', 'shared_topics'],
)
edge_list.to_csv('tiktok_weighted_edges.csv', index=False)

print("\n All files saved!")

# [optional] Extra bipartite graph analysis: rank topic nodes by how many
# songs attach to them (their degree in B) and show the ten largest.
topic_degrees = {
    node: deg
    for node, deg in B.degree()
    if B.nodes[node]['node_type'] == 'topic'
}
ranked_topics = sorted(topic_degrees.items(), key=lambda pair: pair[1], reverse=True)
top_topics = ranked_topics[:10]

print("\n=== Top Topics (connected to most songs) ===")
for topic_id, n_songs in top_topics:
    print(f"{topic_id}: {n_songs} songs")
